reducing duplication of *EmbeddingStoreIT
This commit is contained in:
parent
7c5cade3c0
commit
e0dc387cef
|
@ -69,6 +69,14 @@
|
||||||
<artifactId>slf4j-api</artifactId>
|
<artifactId>slf4j-api</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>dev.langchain4j</groupId>
|
||||||
|
<artifactId>langchain4j-core</artifactId>
|
||||||
|
<classifier>tests</classifier>
|
||||||
|
<type>test-jar</type>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.junit.jupiter</groupId>
|
<groupId>org.junit.jupiter</groupId>
|
||||||
<artifactId>junit-jupiter</artifactId>
|
<artifactId>junit-jupiter</artifactId>
|
||||||
|
|
|
@ -1,28 +1,19 @@
|
||||||
package dev.langchain4j.store.embedding.opensearch;
|
package dev.langchain4j.store.embedding.opensearch;
|
||||||
|
|
||||||
import dev.langchain4j.data.document.Metadata;
|
|
||||||
import dev.langchain4j.data.embedding.Embedding;
|
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
import dev.langchain4j.store.embedding.CosineSimilarity;
|
|
||||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
import dev.langchain4j.store.embedding.RelevanceScore;
|
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
|
||||||
|
import lombok.SneakyThrows;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
import org.opensearch.client.transport.aws.AwsSdk2TransportOptions;
|
import org.opensearch.client.transport.aws.AwsSdk2TransportOptions;
|
||||||
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
|
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
|
||||||
|
|
||||||
import java.util.List;
|
import static dev.langchain4j.internal.Utils.randomUUID;
|
||||||
import java.util.UUID;
|
|
||||||
|
|
||||||
import static java.util.Arrays.asList;
|
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
|
||||||
import static org.assertj.core.data.Percentage.withPercentage;
|
|
||||||
|
|
||||||
@Disabled("Needs OpenSearch running with AWS")
|
@Disabled("Needs OpenSearch running with AWS")
|
||||||
public class OpenSearchEmbeddingStoreAWSTest {
|
public class OpenSearchEmbeddingStoreAWSTest extends EmbeddingStoreIT {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* To run the tests locally, you have to provide an Amazon OpenSearch domain. The code uses
|
* To run the tests locally, you have to provide an Amazon OpenSearch domain. The code uses
|
||||||
|
@ -30,234 +21,36 @@ public class OpenSearchEmbeddingStoreAWSTest {
|
||||||
* your credentials locally, see https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html.
|
* your credentials locally, see https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
private final String domainEndpoint = "your-generated-domain-endpoint-with-no-https.us-east-1.es.amazonaws.com";
|
EmbeddingStore<TextSegment> embeddingStore = OpenSearchEmbeddingStore.builder()
|
||||||
private final ProfileCredentialsProvider credentials = ProfileCredentialsProvider.create("default");
|
.serverUrl("your-generated-domain-endpoint-with-no-https.us-east-1.es.amazonaws.com")
|
||||||
|
|
||||||
private final AwsSdk2TransportOptions transportOptions = AwsSdk2TransportOptions.builder()
|
|
||||||
.setCredentials(credentials)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
private final EmbeddingStore<TextSegment> embeddingStore = OpenSearchEmbeddingStore.builder()
|
|
||||||
.serverUrl(domainEndpoint)
|
|
||||||
.serviceName("es")
|
.serviceName("es")
|
||||||
.region("us-east-1")
|
.region("us-east-1")
|
||||||
.options(transportOptions)
|
.options(AwsSdk2TransportOptions.builder()
|
||||||
|
.setCredentials(ProfileCredentialsProvider.create("default"))
|
||||||
|
.build())
|
||||||
.indexName(randomUUID())
|
.indexName(randomUUID())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||||
|
|
||||||
@Test
|
@Override
|
||||||
void should_add_embedding() throws InterruptedException {
|
protected EmbeddingStore<TextSegment> embeddingStore() {
|
||||||
|
return embeddingStore;
|
||||||
Embedding embedding = embeddingModel.embed(randomUUID()).content();
|
|
||||||
String id = embeddingStore.add(embedding);
|
|
||||||
assertThat(id).isNotNull();
|
|
||||||
|
|
||||||
Thread.sleep(2000);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
|
|
||||||
assertThat(relevant).hasSize(1);
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
|
||||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
|
||||||
assertThat(match.embeddingId()).isEqualTo(id);
|
|
||||||
assertThat(match.embedding()).isEqualTo(embedding);
|
|
||||||
assertThat(match.embedded()).isNull();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Override
|
||||||
void should_add_embedding_with_id() throws InterruptedException {
|
protected EmbeddingModel embeddingModel() {
|
||||||
|
return embeddingModel;
|
||||||
String id = randomUUID();
|
|
||||||
Embedding embedding = embeddingModel.embed(randomUUID()).content();
|
|
||||||
embeddingStore.add(id, embedding);
|
|
||||||
|
|
||||||
Thread.sleep(2000);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
|
|
||||||
assertThat(relevant).hasSize(1);
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
|
||||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
|
||||||
assertThat(match.embeddingId()).isEqualTo(id);
|
|
||||||
assertThat(match.embedding()).isEqualTo(embedding);
|
|
||||||
assertThat(match.embedded()).isNull();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Override
|
||||||
void should_add_embedding_with_segment() throws InterruptedException {
|
protected void ensureStoreIsEmpty() {
|
||||||
|
// TODO fix
|
||||||
TextSegment segment = TextSegment.from(randomUUID());
|
|
||||||
Embedding embedding = embeddingModel.embed(segment.text()).content();
|
|
||||||
|
|
||||||
String id = embeddingStore.add(embedding, segment);
|
|
||||||
assertThat(id).isNotNull();
|
|
||||||
|
|
||||||
Thread.sleep(2000);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
|
|
||||||
assertThat(relevant).hasSize(1);
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
|
||||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
|
||||||
assertThat(match.embeddingId()).isEqualTo(id);
|
|
||||||
assertThat(match.embedding()).isEqualTo(embedding);
|
|
||||||
assertThat(match.embedded()).isEqualTo(segment);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Override
|
||||||
void should_add_embedding_with_segment_with_metadata() throws InterruptedException {
|
@SneakyThrows
|
||||||
|
protected void awaitUntilPersisted() {
|
||||||
TextSegment segment = TextSegment.from(randomUUID(), Metadata.from("test-key", "test-value"));
|
Thread.sleep(1000);
|
||||||
Embedding embedding = embeddingModel.embed(segment.text()).content();
|
|
||||||
|
|
||||||
String id = embeddingStore.add(embedding, segment);
|
|
||||||
assertThat(id).isNotNull();
|
|
||||||
|
|
||||||
Thread.sleep(2000);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
|
|
||||||
assertThat(relevant).hasSize(1);
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
|
||||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
|
||||||
assertThat(match.embeddingId()).isEqualTo(id);
|
|
||||||
assertThat(match.embedding()).isEqualTo(embedding);
|
|
||||||
assertThat(match.embedded()).isEqualTo(segment);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void should_add_multiple_embeddings() throws InterruptedException {
|
|
||||||
|
|
||||||
Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content();
|
|
||||||
Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content();
|
|
||||||
|
|
||||||
List<String> ids = embeddingStore.addAll(asList(firstEmbedding, secondEmbedding));
|
|
||||||
assertThat(ids).hasSize(2);
|
|
||||||
|
|
||||||
Thread.sleep(2000);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
|
|
||||||
assertThat(relevant).hasSize(2);
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
|
||||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
|
|
||||||
assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
|
|
||||||
assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
|
|
||||||
assertThat(firstMatch.embedded()).isNull();
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
|
||||||
assertThat(secondMatch.score()).isBetween(0d, 1d);
|
|
||||||
assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
|
|
||||||
assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
|
|
||||||
assertThat(secondMatch.embedded()).isNull();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void should_add_multiple_embeddings_with_segments() throws InterruptedException {
|
|
||||||
|
|
||||||
TextSegment firstSegment = TextSegment.from(randomUUID());
|
|
||||||
Embedding firstEmbedding = embeddingModel.embed(firstSegment.text()).content();
|
|
||||||
TextSegment secondSegment = TextSegment.from(randomUUID());
|
|
||||||
Embedding secondEmbedding = embeddingModel.embed(secondSegment.text()).content();
|
|
||||||
|
|
||||||
List<String> ids = embeddingStore.addAll(
|
|
||||||
asList(firstEmbedding, secondEmbedding),
|
|
||||||
asList(firstSegment, secondSegment)
|
|
||||||
);
|
|
||||||
assertThat(ids).hasSize(2);
|
|
||||||
|
|
||||||
Thread.sleep(2000);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
|
|
||||||
assertThat(relevant).hasSize(2);
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
|
||||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
|
|
||||||
assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
|
|
||||||
assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
|
|
||||||
assertThat(firstMatch.embedded()).isEqualTo(firstSegment);
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
|
||||||
assertThat(secondMatch.score()).isBetween(0d, 1d);
|
|
||||||
assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
|
|
||||||
assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
|
|
||||||
assertThat(secondMatch.embedded()).isEqualTo(secondSegment);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void should_find_with_min_score() throws InterruptedException {
|
|
||||||
|
|
||||||
String firstId = randomUUID();
|
|
||||||
Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content();
|
|
||||||
embeddingStore.add(firstId, firstEmbedding);
|
|
||||||
|
|
||||||
String secondId = randomUUID();
|
|
||||||
Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content();
|
|
||||||
embeddingStore.add(secondId, secondEmbedding);
|
|
||||||
|
|
||||||
Thread.sleep(2000);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
|
|
||||||
assertThat(relevant).hasSize(2);
|
|
||||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
|
||||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
|
|
||||||
assertThat(firstMatch.embeddingId()).isEqualTo(firstId);
|
|
||||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
|
||||||
assertThat(secondMatch.score()).isBetween(0d, 1d);
|
|
||||||
assertThat(secondMatch.embeddingId()).isEqualTo(secondId);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant2 = embeddingStore.findRelevant(
|
|
||||||
firstEmbedding,
|
|
||||||
10,
|
|
||||||
secondMatch.score() - 0.01
|
|
||||||
);
|
|
||||||
assertThat(relevant2).hasSize(2);
|
|
||||||
assertThat(relevant2.get(0).embeddingId()).isEqualTo(firstId);
|
|
||||||
assertThat(relevant2.get(1).embeddingId()).isEqualTo(secondId);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant3 = embeddingStore.findRelevant(
|
|
||||||
firstEmbedding,
|
|
||||||
10,
|
|
||||||
secondMatch.score()
|
|
||||||
);
|
|
||||||
assertThat(relevant3).hasSize(2);
|
|
||||||
assertThat(relevant3.get(0).embeddingId()).isEqualTo(firstId);
|
|
||||||
assertThat(relevant3.get(1).embeddingId()).isEqualTo(secondId);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant4 = embeddingStore.findRelevant(
|
|
||||||
firstEmbedding,
|
|
||||||
10,
|
|
||||||
secondMatch.score() + 0.01
|
|
||||||
);
|
|
||||||
assertThat(relevant4).hasSize(1);
|
|
||||||
assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void should_return_correct_score() throws InterruptedException {
|
|
||||||
|
|
||||||
Embedding embedding = embeddingModel.embed("hello").content();
|
|
||||||
|
|
||||||
String id = embeddingStore.add(embedding);
|
|
||||||
assertThat(id).isNotNull();
|
|
||||||
|
|
||||||
Embedding referenceEmbedding = embeddingModel.embed("hi").content();
|
|
||||||
|
|
||||||
Thread.sleep(2000);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(referenceEmbedding, 1);
|
|
||||||
assertThat(relevant).hasSize(1);
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
|
||||||
assertThat(match.score()).isCloseTo(
|
|
||||||
RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)),
|
|
||||||
withPercentage(1)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static String randomUUID() {
|
|
||||||
return UUID.randomUUID().toString();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,30 +1,21 @@
|
||||||
package dev.langchain4j.store.embedding.opensearch;
|
package dev.langchain4j.store.embedding.opensearch;
|
||||||
|
|
||||||
import dev.langchain4j.data.document.Metadata;
|
|
||||||
import dev.langchain4j.data.embedding.Embedding;
|
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
import dev.langchain4j.store.embedding.CosineSimilarity;
|
|
||||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
import dev.langchain4j.store.embedding.RelevanceScore;
|
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
|
||||||
|
import lombok.SneakyThrows;
|
||||||
import org.junit.jupiter.api.BeforeAll;
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
import org.opensearch.testcontainers.OpensearchContainer;
|
import org.opensearch.testcontainers.OpensearchContainer;
|
||||||
import org.testcontainers.junit.jupiter.Container;
|
import org.testcontainers.junit.jupiter.Container;
|
||||||
import org.testcontainers.utility.DockerImageName;
|
import org.testcontainers.utility.DockerImageName;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import static dev.langchain4j.internal.Utils.randomUUID;
|
import static dev.langchain4j.internal.Utils.randomUUID;
|
||||||
import static java.util.Arrays.asList;
|
|
||||||
import static org.assertj.core.api.Assertions.assertThat;
|
|
||||||
import static org.assertj.core.data.Percentage.withPercentage;
|
|
||||||
|
|
||||||
@Disabled("Needs OpenSearch running locally")
|
@Disabled("Needs OpenSearch running locally")
|
||||||
class OpenSearchEmbeddingStoreLocalTest {
|
class OpenSearchEmbeddingStoreLocalTest extends EmbeddingStoreIT {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* To run the tests locally, you don't need to have OpenSearch up-and-running. This implementation
|
* To run the tests locally, you don't need to have OpenSearch up-and-running. This implementation
|
||||||
|
@ -33,228 +24,39 @@ class OpenSearchEmbeddingStoreLocalTest {
|
||||||
*/
|
*/
|
||||||
|
|
||||||
@Container
|
@Container
|
||||||
private static final OpensearchContainer opensearch =
|
static OpensearchContainer opensearch =
|
||||||
new OpensearchContainer(DockerImageName.parse("opensearchproject/opensearch:2.10.0"));
|
new OpensearchContainer(DockerImageName.parse("opensearchproject/opensearch:2.10.0"));
|
||||||
|
|
||||||
private final EmbeddingStore<TextSegment> embeddingStore = OpenSearchEmbeddingStore.builder()
|
EmbeddingStore<TextSegment> embeddingStore = OpenSearchEmbeddingStore.builder()
|
||||||
.serverUrl(opensearch.getHttpHostAddress())
|
.serverUrl(opensearch.getHttpHostAddress())
|
||||||
.indexName(randomUUID())
|
.indexName(randomUUID())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||||
|
|
||||||
@BeforeAll
|
@BeforeAll
|
||||||
static void startOpenSearch() {
|
static void startOpenSearch() {
|
||||||
opensearch.start();
|
opensearch.start();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Override
|
||||||
void should_add_embedding() throws InterruptedException {
|
protected EmbeddingStore<TextSegment> embeddingStore() {
|
||||||
|
return embeddingStore;
|
||||||
Embedding embedding = embeddingModel.embed(randomUUID()).content();
|
|
||||||
String id = embeddingStore.add(embedding);
|
|
||||||
assertThat(id).isNotNull();
|
|
||||||
|
|
||||||
Thread.sleep(2000);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
|
|
||||||
assertThat(relevant).hasSize(1);
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
|
||||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
|
||||||
assertThat(match.embeddingId()).isEqualTo(id);
|
|
||||||
assertThat(match.embedding()).isEqualTo(embedding);
|
|
||||||
assertThat(match.embedded()).isNull();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Override
|
||||||
void should_add_embedding_with_id() throws InterruptedException {
|
protected EmbeddingModel embeddingModel() {
|
||||||
|
return embeddingModel;
|
||||||
String id = randomUUID();
|
|
||||||
Embedding embedding = embeddingModel.embed(randomUUID()).content();
|
|
||||||
embeddingStore.add(id, embedding);
|
|
||||||
|
|
||||||
Thread.sleep(2000);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
|
|
||||||
assertThat(relevant).hasSize(1);
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
|
||||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
|
||||||
assertThat(match.embeddingId()).isEqualTo(id);
|
|
||||||
assertThat(match.embedding()).isEqualTo(embedding);
|
|
||||||
assertThat(match.embedded()).isNull();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Override
|
||||||
void should_add_embedding_with_segment() throws InterruptedException {
|
protected void ensureStoreIsEmpty() {
|
||||||
|
// TODO fix
|
||||||
TextSegment segment = TextSegment.from(randomUUID());
|
|
||||||
Embedding embedding = embeddingModel.embed(segment.text()).content();
|
|
||||||
|
|
||||||
String id = embeddingStore.add(embedding, segment);
|
|
||||||
assertThat(id).isNotNull();
|
|
||||||
|
|
||||||
Thread.sleep(2000);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
|
|
||||||
assertThat(relevant).hasSize(1);
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
|
||||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
|
||||||
assertThat(match.embeddingId()).isEqualTo(id);
|
|
||||||
assertThat(match.embedding()).isEqualTo(embedding);
|
|
||||||
assertThat(match.embedded()).isEqualTo(segment);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Override
|
||||||
void should_add_embedding_with_segment_with_metadata() throws InterruptedException {
|
@SneakyThrows
|
||||||
|
protected void awaitUntilPersisted() {
|
||||||
TextSegment segment = TextSegment.from(randomUUID(), Metadata.from("test-key", "test-value"));
|
Thread.sleep(1000);
|
||||||
Embedding embedding = embeddingModel.embed(segment.text()).content();
|
|
||||||
|
|
||||||
String id = embeddingStore.add(embedding, segment);
|
|
||||||
assertThat(id).isNotNull();
|
|
||||||
|
|
||||||
Thread.sleep(2000);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
|
|
||||||
assertThat(relevant).hasSize(1);
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
|
||||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
|
||||||
assertThat(match.embeddingId()).isEqualTo(id);
|
|
||||||
assertThat(match.embedding()).isEqualTo(embedding);
|
|
||||||
assertThat(match.embedded()).isEqualTo(segment);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void should_add_multiple_embeddings() throws InterruptedException {
|
|
||||||
|
|
||||||
Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content();
|
|
||||||
Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content();
|
|
||||||
|
|
||||||
List<String> ids = embeddingStore.addAll(asList(firstEmbedding, secondEmbedding));
|
|
||||||
assertThat(ids).hasSize(2);
|
|
||||||
|
|
||||||
Thread.sleep(2000);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
|
|
||||||
assertThat(relevant).hasSize(2);
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
|
||||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
|
|
||||||
assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
|
|
||||||
assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
|
|
||||||
assertThat(firstMatch.embedded()).isNull();
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
|
||||||
assertThat(secondMatch.score()).isBetween(0d, 1d);
|
|
||||||
assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
|
|
||||||
assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
|
|
||||||
assertThat(secondMatch.embedded()).isNull();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void should_add_multiple_embeddings_with_segments() throws InterruptedException {
|
|
||||||
|
|
||||||
TextSegment firstSegment = TextSegment.from(randomUUID());
|
|
||||||
Embedding firstEmbedding = embeddingModel.embed(firstSegment.text()).content();
|
|
||||||
TextSegment secondSegment = TextSegment.from(randomUUID());
|
|
||||||
Embedding secondEmbedding = embeddingModel.embed(secondSegment.text()).content();
|
|
||||||
|
|
||||||
List<String> ids = embeddingStore.addAll(
|
|
||||||
asList(firstEmbedding, secondEmbedding),
|
|
||||||
asList(firstSegment, secondSegment)
|
|
||||||
);
|
|
||||||
assertThat(ids).hasSize(2);
|
|
||||||
|
|
||||||
Thread.sleep(2000);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
|
|
||||||
assertThat(relevant).hasSize(2);
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
|
||||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
|
|
||||||
assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
|
|
||||||
assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
|
|
||||||
assertThat(firstMatch.embedded()).isEqualTo(firstSegment);
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
|
||||||
assertThat(secondMatch.score()).isBetween(0d, 1d);
|
|
||||||
assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
|
|
||||||
assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
|
|
||||||
assertThat(secondMatch.embedded()).isEqualTo(secondSegment);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void should_find_with_min_score() throws InterruptedException {
|
|
||||||
|
|
||||||
String firstId = randomUUID();
|
|
||||||
Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content();
|
|
||||||
embeddingStore.add(firstId, firstEmbedding);
|
|
||||||
|
|
||||||
String secondId = randomUUID();
|
|
||||||
Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content();
|
|
||||||
embeddingStore.add(secondId, secondEmbedding);
|
|
||||||
|
|
||||||
Thread.sleep(2000);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
|
|
||||||
assertThat(relevant).hasSize(2);
|
|
||||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
|
||||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
|
|
||||||
assertThat(firstMatch.embeddingId()).isEqualTo(firstId);
|
|
||||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
|
||||||
assertThat(secondMatch.score()).isBetween(0d, 1d);
|
|
||||||
assertThat(secondMatch.embeddingId()).isEqualTo(secondId);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant2 = embeddingStore.findRelevant(
|
|
||||||
firstEmbedding,
|
|
||||||
10,
|
|
||||||
secondMatch.score() - 0.01
|
|
||||||
);
|
|
||||||
assertThat(relevant2).hasSize(2);
|
|
||||||
assertThat(relevant2.get(0).embeddingId()).isEqualTo(firstId);
|
|
||||||
assertThat(relevant2.get(1).embeddingId()).isEqualTo(secondId);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant3 = embeddingStore.findRelevant(
|
|
||||||
firstEmbedding,
|
|
||||||
10,
|
|
||||||
secondMatch.score()
|
|
||||||
);
|
|
||||||
assertThat(relevant3).hasSize(2);
|
|
||||||
assertThat(relevant3.get(0).embeddingId()).isEqualTo(firstId);
|
|
||||||
assertThat(relevant3.get(1).embeddingId()).isEqualTo(secondId);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant4 = embeddingStore.findRelevant(
|
|
||||||
firstEmbedding,
|
|
||||||
10,
|
|
||||||
secondMatch.score() + 0.01
|
|
||||||
);
|
|
||||||
assertThat(relevant4).hasSize(1);
|
|
||||||
assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void should_return_correct_score() throws InterruptedException {
|
|
||||||
|
|
||||||
Embedding embedding = embeddingModel.embed("hello").content();
|
|
||||||
|
|
||||||
String id = embeddingStore.add(embedding);
|
|
||||||
assertThat(id).isNotNull();
|
|
||||||
|
|
||||||
Embedding referenceEmbedding = embeddingModel.embed("hi").content();
|
|
||||||
|
|
||||||
Thread.sleep(2000);
|
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(referenceEmbedding, 1);
|
|
||||||
assertThat(relevant).hasSize(1);
|
|
||||||
|
|
||||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
|
||||||
assertThat(match.score()).isCloseTo(
|
|
||||||
RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)),
|
|
||||||
withPercentage(1)
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue