reducing duplication of *EmbeddingStoreIT
This commit is contained in:
parent
e467beb64a
commit
7c5cade3c0
|
@ -3,14 +3,14 @@ package dev.langchain4j.store.embedding.chroma;
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
import dev.langchain4j.store.embedding.AbstractEmbeddingStoreIT;
|
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
|
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
|
|
||||||
import static dev.langchain4j.internal.Utils.randomUUID;
|
import static dev.langchain4j.internal.Utils.randomUUID;
|
||||||
|
|
||||||
@Disabled("needs Chroma running locally")
|
@Disabled("needs Chroma running locally")
|
||||||
class ChromaEmbeddingStoreIT extends AbstractEmbeddingStoreIT {
|
class ChromaEmbeddingStoreIT extends EmbeddingStoreIT {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* First ensure you have Chroma running locally. If not, then:
|
* First ensure you have Chroma running locally. If not, then:
|
||||||
|
@ -19,12 +19,12 @@ class ChromaEmbeddingStoreIT extends AbstractEmbeddingStoreIT {
|
||||||
* - Wait until Chroma is ready to serve (may take a few minutes)
|
* - Wait until Chroma is ready to serve (may take a few minutes)
|
||||||
*/
|
*/
|
||||||
|
|
||||||
private final EmbeddingStore<TextSegment> embeddingStore = ChromaEmbeddingStore.builder()
|
EmbeddingStore<TextSegment> embeddingStore = ChromaEmbeddingStore.builder()
|
||||||
.baseUrl("http://localhost:8000")
|
.baseUrl("http://localhost:8000")
|
||||||
.collectionName(randomUUID())
|
.collectionName(randomUUID())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected EmbeddingStore<TextSegment> embeddingStore() {
|
protected EmbeddingStore<TextSegment> embeddingStore() {
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
package dev.langchain4j.store.embedding;
|
||||||
|
|
||||||
|
import dev.langchain4j.data.document.Metadata;
|
||||||
|
import dev.langchain4j.data.embedding.Embedding;
|
||||||
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
import static org.assertj.core.data.Percentage.withPercentage;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A minimum set of tests that each implementation of {@link EmbeddingStore} must pass.
|
||||||
|
*/
|
||||||
|
public abstract class EmbeddingStoreIT extends EmbeddingStoreWithoutMetadataIT {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void should_add_embedding_with_segment_with_metadata() {
|
||||||
|
|
||||||
|
TextSegment segment = TextSegment.from("hello", Metadata.from("test-key", "test-value"));
|
||||||
|
Embedding embedding = embeddingModel().embed(segment.text()).content();
|
||||||
|
|
||||||
|
String id = embeddingStore().add(embedding, segment);
|
||||||
|
assertThat(id).isNotBlank();
|
||||||
|
|
||||||
|
awaitUntilPersisted();
|
||||||
|
|
||||||
|
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 10);
|
||||||
|
assertThat(relevant).hasSize(1);
|
||||||
|
|
||||||
|
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
||||||
|
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
||||||
|
assertThat(match.embeddingId()).isEqualTo(id);
|
||||||
|
assertThat(match.embedding()).isEqualTo(embedding);
|
||||||
|
assertThat(match.embedded()).isEqualTo(segment);
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,6 +1,5 @@
|
||||||
package dev.langchain4j.store.embedding;
|
package dev.langchain4j.store.embedding;
|
||||||
|
|
||||||
import dev.langchain4j.data.document.Metadata;
|
|
||||||
import dev.langchain4j.data.embedding.Embedding;
|
import dev.langchain4j.data.embedding.Embedding;
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
|
@ -14,10 +13,7 @@ import static java.util.Arrays.asList;
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
import static org.assertj.core.data.Percentage.withPercentage;
|
import static org.assertj.core.data.Percentage.withPercentage;
|
||||||
|
|
||||||
/**
|
public abstract class EmbeddingStoreWithoutMetadataIT {
|
||||||
* A minimum set of tests that each implementation of {@link EmbeddingStore} must pass.
|
|
||||||
*/
|
|
||||||
public abstract class AbstractEmbeddingStoreIT {
|
|
||||||
|
|
||||||
protected abstract EmbeddingStore<TextSegment> embeddingStore();
|
protected abstract EmbeddingStore<TextSegment> embeddingStore();
|
||||||
|
|
||||||
|
@ -30,7 +26,7 @@ public abstract class AbstractEmbeddingStoreIT {
|
||||||
|
|
||||||
protected void ensureStoreIsEmpty() {
|
protected void ensureStoreIsEmpty() {
|
||||||
Embedding embedding = embeddingModel().embed("hello").content();
|
Embedding embedding = embeddingModel().embed("hello").content();
|
||||||
assertThat(embeddingStore().findRelevant(embedding, Integer.MAX_VALUE)).isEmpty();
|
assertThat(embeddingStore().findRelevant(embedding, 1000)).isEmpty();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -94,27 +90,6 @@ public abstract class AbstractEmbeddingStoreIT {
|
||||||
assertThat(match.embedded()).isEqualTo(segment);
|
assertThat(match.embedded()).isEqualTo(segment);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
void should_add_embedding_with_segment_with_metadata() {
|
|
||||||
|
|
||||||
TextSegment segment = TextSegment.from("hello", Metadata.from("test-key", "test-value"));
|
|
||||||
Embedding embedding = embeddingModel().embed(segment.text()).content();
|
|
||||||
|
|
||||||
String id = embeddingStore().add(embedding, segment);
|
|
||||||
assertThat(id).isNotBlank();
|
|
||||||
|
|
||||||
awaitUntilPersisted();
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore().findRelevant(embedding, 10);
|
|
||||||
assertThat(relevant).hasSize(1);
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
|
||||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
|
||||||
assertThat(match.embeddingId()).isEqualTo(id);
|
|
||||||
assertThat(match.embedding()).isEqualTo(embedding);
|
|
||||||
assertThat(match.embedded()).isEqualTo(segment);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void should_add_multiple_embeddings() {
|
void should_add_multiple_embeddings() {
|
||||||
|
|
|
@ -3,15 +3,15 @@ package dev.langchain4j.store.embedding.elasticsearch;
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
import dev.langchain4j.store.embedding.AbstractEmbeddingStoreIT;
|
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
|
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
|
||||||
import lombok.SneakyThrows;
|
import lombok.SneakyThrows;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
|
|
||||||
import static dev.langchain4j.internal.Utils.randomUUID;
|
import static dev.langchain4j.internal.Utils.randomUUID;
|
||||||
|
|
||||||
@Disabled("needs Elasticsearch to be running locally")
|
@Disabled("needs Elasticsearch to be running locally")
|
||||||
class ElasticsearchEmbeddingStoreIT extends AbstractEmbeddingStoreIT {
|
class ElasticsearchEmbeddingStoreIT extends EmbeddingStoreIT {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* First start elasticsearch locally:
|
* First start elasticsearch locally:
|
||||||
|
@ -19,13 +19,13 @@ class ElasticsearchEmbeddingStoreIT extends AbstractEmbeddingStoreIT {
|
||||||
* docker run -d -p 9200:9200 -p 9300:9300 -e discovery.type=single-node -e xpack.security.enabled=false docker.elastic.co/elasticsearch/elasticsearch:8.9.0
|
* docker run -d -p 9200:9200 -p 9300:9300 -e discovery.type=single-node -e xpack.security.enabled=false docker.elastic.co/elasticsearch/elasticsearch:8.9.0
|
||||||
*/
|
*/
|
||||||
|
|
||||||
private final EmbeddingStore<TextSegment> embeddingStore = ElasticsearchEmbeddingStore.builder()
|
EmbeddingStore<TextSegment> embeddingStore = ElasticsearchEmbeddingStore.builder()
|
||||||
.serverUrl("http://localhost:9200")
|
.serverUrl("http://localhost:9200")
|
||||||
.indexName(randomUUID())
|
.indexName(randomUUID())
|
||||||
.dimension(384)
|
.dimension(384)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected EmbeddingStore<TextSegment> embeddingStore() {
|
protected EmbeddingStore<TextSegment> embeddingStore() {
|
||||||
|
|
|
@ -24,18 +24,13 @@
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>io.milvus</groupId>
|
<groupId>io.milvus</groupId>
|
||||||
<artifactId>milvus-sdk-java</artifactId>
|
<artifactId>milvus-sdk-java</artifactId>
|
||||||
<version>2.3.1</version>
|
<version>2.3.3</version>
|
||||||
<exclusions>
|
<exclusions>
|
||||||
<!-- due to CVE-2022-41915 vulnerability -->
|
<!-- due to CVE-2022-41915 vulnerability -->
|
||||||
<exclusion>
|
<exclusion>
|
||||||
<groupId>io.netty</groupId>
|
<groupId>io.netty</groupId>
|
||||||
<artifactId>netty-codec</artifactId>
|
<artifactId>netty-codec</artifactId>
|
||||||
</exclusion>
|
</exclusion>
|
||||||
<!-- due to CVE-2022-42889 vulnerability -->
|
|
||||||
<exclusion>
|
|
||||||
<groupId>org.apache.commons</groupId>
|
|
||||||
<artifactId>commons-text</artifactId>
|
|
||||||
</exclusion>
|
|
||||||
</exclusions>
|
</exclusions>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
|
@ -43,10 +38,13 @@
|
||||||
<artifactId>netty-codec</artifactId>
|
<artifactId>netty-codec</artifactId>
|
||||||
<version>${netty.version}</version>
|
<version>${netty.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.apache.commons</groupId>
|
<groupId>dev.langchain4j</groupId>
|
||||||
<artifactId>commons-text</artifactId>
|
<artifactId>langchain4j-core</artifactId>
|
||||||
<version>1.10.0</version>
|
<classifier>tests</classifier>
|
||||||
|
<type>test-jar</type>
|
||||||
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
|
|
|
@ -6,6 +6,7 @@ import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||||
import dev.langchain4j.store.embedding.RelevanceScore;
|
import dev.langchain4j.store.embedding.RelevanceScore;
|
||||||
import io.milvus.client.MilvusServiceClient;
|
import io.milvus.client.MilvusServiceClient;
|
||||||
import io.milvus.common.clientenum.ConsistencyLevelEnum;
|
import io.milvus.common.clientenum.ConsistencyLevelEnum;
|
||||||
|
import io.milvus.exception.ParamException;
|
||||||
import io.milvus.response.QueryResultsWrapper;
|
import io.milvus.response.QueryResultsWrapper;
|
||||||
import io.milvus.response.SearchResultsWrapper;
|
import io.milvus.response.SearchResultsWrapper;
|
||||||
|
|
||||||
|
@ -47,10 +48,15 @@ class Mapper {
|
||||||
boolean queryForVectorOnSearch) {
|
boolean queryForVectorOnSearch) {
|
||||||
List<EmbeddingMatch<TextSegment>> matches = new ArrayList<>();
|
List<EmbeddingMatch<TextSegment>> matches = new ArrayList<>();
|
||||||
|
|
||||||
List<String> rowIds = (List<String>) resultsWrapper.getFieldWrapper(ID_FIELD_NAME).getFieldData();
|
|
||||||
Map<String, Embedding> idToEmbedding = new HashMap<>();
|
Map<String, Embedding> idToEmbedding = new HashMap<>();
|
||||||
if (queryForVectorOnSearch) {
|
if (queryForVectorOnSearch) {
|
||||||
|
try {
|
||||||
|
List<String> rowIds = (List<String>) resultsWrapper.getFieldWrapper(ID_FIELD_NAME).getFieldData();
|
||||||
idToEmbedding.putAll(queryEmbeddings(milvusClient, collectionName, rowIds, consistencyLevel));
|
idToEmbedding.putAll(queryEmbeddings(milvusClient, collectionName, rowIds, consistencyLevel));
|
||||||
|
} catch (ParamException e) {
|
||||||
|
// There is no way to check if the result is empty or not.
|
||||||
|
// If the result is empty, the exception will be thrown.
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < resultsWrapper.getRowRecords().size(); i++) {
|
for (int i = 0; i < resultsWrapper.getRowRecords().size(); i++) {
|
||||||
|
|
|
@ -131,7 +131,9 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||||
retrieveEmbeddingsOnSearch
|
retrieveEmbeddingsOnSearch
|
||||||
);
|
);
|
||||||
|
|
||||||
return matches.stream().filter(match -> match.score() >= minScore).collect(toList());
|
return matches.stream()
|
||||||
|
.filter(match -> match.score() >= minScore)
|
||||||
|
.collect(toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
private void addInternal(String id, Embedding embedding, TextSegment textSegment) {
|
private void addInternal(String id, Embedding embedding, TextSegment textSegment) {
|
||||||
|
|
|
@ -4,12 +4,12 @@ import dev.langchain4j.data.embedding.Embedding;
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
import dev.langchain4j.store.embedding.CosineSimilarity;
|
|
||||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
import dev.langchain4j.store.embedding.RelevanceScore;
|
import dev.langchain4j.store.embedding.EmbeddingStoreWithoutMetadataIT;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ import static org.assertj.core.api.Assertions.assertThat;
|
||||||
import static org.assertj.core.data.Percentage.withPercentage;
|
import static org.assertj.core.data.Percentage.withPercentage;
|
||||||
|
|
||||||
@Disabled("needs Milvus running locally")
|
@Disabled("needs Milvus running locally")
|
||||||
class MilvusEmbeddingStoreIT {
|
class MilvusEmbeddingStoreIT extends EmbeddingStoreWithoutMetadataIT {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* First run Milvus locally:
|
* First run Milvus locally:
|
||||||
|
@ -33,221 +33,44 @@ class MilvusEmbeddingStoreIT {
|
||||||
.port(19530)
|
.port(19530)
|
||||||
.collectionName("collection_" + randomUUID().replace("-", ""))
|
.collectionName("collection_" + randomUUID().replace("-", ""))
|
||||||
.dimension(384)
|
.dimension(384)
|
||||||
|
.retrieveEmbeddingsOnSearch(true)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||||
|
|
||||||
@Test
|
@Override
|
||||||
void should_add_embedding() {
|
protected EmbeddingStore<TextSegment> embeddingStore() {
|
||||||
|
return embeddingStore;
|
||||||
|
}
|
||||||
|
|
||||||
Embedding embedding = embeddingModel.embed(randomUUID()).content();
|
@Override
|
||||||
|
protected EmbeddingModel embeddingModel() {
|
||||||
String id = embeddingStore.add(embedding);
|
return embeddingModel;
|
||||||
assertThat(id).isNotNull();
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
|
|
||||||
assertThat(relevant).hasSize(1);
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
|
||||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
|
||||||
assertThat(match.embeddingId()).isEqualTo(id);
|
|
||||||
assertThat(match.embedding()).isNull();
|
|
||||||
assertThat(match.embedded()).isNull();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void should_add_embedding_with_id() {
|
void should_not_retrieve_embeddings_when_searching() {
|
||||||
|
|
||||||
String id = randomUUID();
|
|
||||||
Embedding embedding = embeddingModel.embed(randomUUID()).content();
|
|
||||||
|
|
||||||
embeddingStore.add(id, embedding);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
|
|
||||||
assertThat(relevant).hasSize(1);
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
|
||||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
|
||||||
assertThat(match.embeddingId()).isEqualTo(id);
|
|
||||||
assertThat(match.embedding()).isNull();
|
|
||||||
assertThat(match.embedded()).isNull();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void should_add_embedding_with_segment() {
|
|
||||||
|
|
||||||
TextSegment segment = TextSegment.from(randomUUID());
|
|
||||||
Embedding embedding = embeddingModel.embed(segment.text()).content();
|
|
||||||
|
|
||||||
String id = embeddingStore.add(embedding, segment);
|
|
||||||
assertThat(id).isNotNull();
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
|
|
||||||
assertThat(relevant).hasSize(1);
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
|
||||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
|
||||||
assertThat(match.embeddingId()).isEqualTo(id);
|
|
||||||
assertThat(match.embedding()).isNull();
|
|
||||||
assertThat(match.embedded()).isEqualTo(segment);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void should_add_multiple_embeddings() {
|
|
||||||
|
|
||||||
Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content();
|
|
||||||
Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content();
|
|
||||||
|
|
||||||
List<String> ids = embeddingStore.addAll(asList(firstEmbedding, secondEmbedding));
|
|
||||||
assertThat(ids).hasSize(2);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
|
|
||||||
assertThat(relevant).hasSize(2);
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
|
||||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
|
|
||||||
assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
|
|
||||||
assertThat(firstMatch.embedding()).isNull();
|
|
||||||
assertThat(firstMatch.embedded()).isNull();
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
|
||||||
assertThat(secondMatch.score()).isBetween(0d, 1d);
|
|
||||||
assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
|
|
||||||
assertThat(secondMatch.embedding()).isNull();
|
|
||||||
assertThat(secondMatch.embedded()).isNull();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void should_add_multiple_embeddings_with_segments() {
|
|
||||||
|
|
||||||
TextSegment firstSegment = TextSegment.from(randomUUID());
|
|
||||||
Embedding firstEmbedding = embeddingModel.embed(firstSegment.text()).content();
|
|
||||||
TextSegment secondSegment = TextSegment.from(randomUUID());
|
|
||||||
Embedding secondEmbedding = embeddingModel.embed(secondSegment.text()).content();
|
|
||||||
|
|
||||||
List<String> ids = embeddingStore.addAll(
|
|
||||||
asList(firstEmbedding, secondEmbedding),
|
|
||||||
asList(firstSegment, secondSegment)
|
|
||||||
);
|
|
||||||
assertThat(ids).hasSize(2);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
|
|
||||||
assertThat(relevant).hasSize(2);
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
|
||||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
|
|
||||||
assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
|
|
||||||
assertThat(firstMatch.embedding()).isNull();
|
|
||||||
assertThat(firstMatch.embedded()).isEqualTo(firstSegment);
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
|
||||||
assertThat(secondMatch.score()).isBetween(0d, 1d);
|
|
||||||
assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
|
|
||||||
assertThat(secondMatch.embedding()).isNull();
|
|
||||||
assertThat(secondMatch.embedded()).isEqualTo(secondSegment);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void should_find_with_min_score() {
|
|
||||||
|
|
||||||
String firstId = randomUUID();
|
|
||||||
Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content();
|
|
||||||
embeddingStore.add(firstId, firstEmbedding);
|
|
||||||
|
|
||||||
String secondId = randomUUID();
|
|
||||||
Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content();
|
|
||||||
embeddingStore.add(secondId, secondEmbedding);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
|
|
||||||
assertThat(relevant).hasSize(2);
|
|
||||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
|
||||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
|
|
||||||
assertThat(firstMatch.embeddingId()).isEqualTo(firstId);
|
|
||||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
|
||||||
assertThat(secondMatch.score()).isBetween(0d, 1d);
|
|
||||||
assertThat(secondMatch.embeddingId()).isEqualTo(secondId);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant2 = embeddingStore.findRelevant(
|
|
||||||
firstEmbedding,
|
|
||||||
10,
|
|
||||||
secondMatch.score() - 0.01
|
|
||||||
);
|
|
||||||
assertThat(relevant2).hasSize(2);
|
|
||||||
assertThat(relevant2.get(0).embeddingId()).isEqualTo(firstId);
|
|
||||||
assertThat(relevant2.get(1).embeddingId()).isEqualTo(secondId);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant3 = embeddingStore.findRelevant(
|
|
||||||
firstEmbedding,
|
|
||||||
10,
|
|
||||||
secondMatch.score()
|
|
||||||
);
|
|
||||||
assertThat(relevant3).hasSize(2);
|
|
||||||
assertThat(relevant3.get(0).embeddingId()).isEqualTo(firstId);
|
|
||||||
assertThat(relevant3.get(1).embeddingId()).isEqualTo(secondId);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant4 = embeddingStore.findRelevant(
|
|
||||||
firstEmbedding,
|
|
||||||
10,
|
|
||||||
secondMatch.score() + 0.01
|
|
||||||
);
|
|
||||||
assertThat(relevant4).hasSize(1);
|
|
||||||
assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void should_return_correct_score() {
|
|
||||||
|
|
||||||
Embedding embedding = embeddingModel.embed("hello").content();
|
|
||||||
|
|
||||||
String id = embeddingStore.add(embedding);
|
|
||||||
assertThat(id).isNotNull();
|
|
||||||
|
|
||||||
Embedding referenceEmbedding = embeddingModel.embed("hi").content();
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(referenceEmbedding, 1);
|
|
||||||
assertThat(relevant).hasSize(1);
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
|
||||||
assertThat(match.score()).isCloseTo(
|
|
||||||
RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)),
|
|
||||||
withPercentage(1)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void should_retrieve_embeddings_when_searching() {
|
|
||||||
|
|
||||||
EmbeddingStore<TextSegment> embeddingStore = MilvusEmbeddingStore.builder()
|
EmbeddingStore<TextSegment> embeddingStore = MilvusEmbeddingStore.builder()
|
||||||
.host("localhost")
|
.host("localhost")
|
||||||
.port(19530)
|
.port(19530)
|
||||||
.collectionName("collection_" + randomUUID().replace("-", ""))
|
.collectionName("collection_" + randomUUID().replace("-", ""))
|
||||||
.dimension(384)
|
.dimension(384)
|
||||||
.retrieveEmbeddingsOnSearch(true)
|
.retrieveEmbeddingsOnSearch(false)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content();
|
Embedding firstEmbedding = embeddingModel.embed("hello").content();
|
||||||
Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content();
|
Embedding secondEmbedding = embeddingModel.embed("hi").content();
|
||||||
|
embeddingStore.addAll(asList(firstEmbedding, secondEmbedding));
|
||||||
List<String> ids = embeddingStore.addAll(asList(firstEmbedding, secondEmbedding));
|
|
||||||
assertThat(ids).hasSize(2);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
|
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
|
||||||
assertThat(relevant).hasSize(2);
|
assertThat(relevant).hasSize(2);
|
||||||
|
assertThat(relevant.get(0).embedding()).isNull();
|
||||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
assertThat(relevant.get(1).embedding()).isNull();
|
||||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
|
|
||||||
assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
|
|
||||||
assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
|
|
||||||
assertThat(firstMatch.embedded()).isNull();
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
|
||||||
assertThat(secondMatch.score()).isBetween(0d, 1d);
|
|
||||||
assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
|
|
||||||
assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
|
|
||||||
assertThat(secondMatch.embedded()).isNull();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@EnabledIfEnvironmentVariable(named = "MILVUS_API_KEY", matches = ".+")
|
||||||
void should_use_cloud_instance() {
|
void should_use_cloud_instance() {
|
||||||
|
|
||||||
EmbeddingStore<TextSegment> embeddingStore = MilvusEmbeddingStore.builder()
|
EmbeddingStore<TextSegment> embeddingStore = MilvusEmbeddingStore.builder()
|
||||||
|
|
Loading…
Reference in New Issue