diff --git a/langchain4j-opensearch/pom.xml b/langchain4j-opensearch/pom.xml index 8dfd228bc..a106bece5 100644 --- a/langchain4j-opensearch/pom.xml +++ b/langchain4j-opensearch/pom.xml @@ -69,6 +69,14 @@ slf4j-api + + dev.langchain4j + langchain4j-core + tests + test-jar + test + + org.junit.jupiter junit-jupiter diff --git a/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreAWSTest.java b/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreAWSTest.java index 8e709b1a2..d70970563 100644 --- a/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreAWSTest.java +++ b/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreAWSTest.java @@ -1,28 +1,19 @@ package dev.langchain4j.store.embedding.opensearch; -import dev.langchain4j.data.document.Metadata; -import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; -import dev.langchain4j.store.embedding.CosineSimilarity; -import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingStore; -import dev.langchain4j.store.embedding.RelevanceScore; +import dev.langchain4j.store.embedding.EmbeddingStoreIT; +import lombok.SneakyThrows; import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; import org.opensearch.client.transport.aws.AwsSdk2TransportOptions; import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider; -import java.util.List; -import java.util.UUID; - -import static java.util.Arrays.asList; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.data.Percentage.withPercentage; +import static dev.langchain4j.internal.Utils.randomUUID; @Disabled("Needs OpenSearch running with AWS") -public class OpenSearchEmbeddingStoreAWSTest { +public class OpenSearchEmbeddingStoreAWSTest extends EmbeddingStoreIT { /** * To run the tests locally, you have to provide an Amazon OpenSearch domain. The code uses @@ -30,234 +21,36 @@ public class OpenSearchEmbeddingStoreAWSTest { * your credentials locally, see https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html. */ - private final String domainEndpoint = "your-generated-domain-endpoint-with-no-https.us-east-1.es.amazonaws.com"; - private final ProfileCredentialsProvider credentials = ProfileCredentialsProvider.create("default"); - - private final AwsSdk2TransportOptions transportOptions = AwsSdk2TransportOptions.builder() - .setCredentials(credentials) - .build(); - - private final EmbeddingStore embeddingStore = OpenSearchEmbeddingStore.builder() - .serverUrl(domainEndpoint) + EmbeddingStore embeddingStore = OpenSearchEmbeddingStore.builder() + .serverUrl("your-generated-domain-endpoint-with-no-https.us-east-1.es.amazonaws.com") .serviceName("es") .region("us-east-1") - .options(transportOptions) + .options(AwsSdk2TransportOptions.builder() + .setCredentials(ProfileCredentialsProvider.create("default")) + .build()) .indexName(randomUUID()) .build(); - private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); - @Test - void should_add_embedding() throws InterruptedException { - - Embedding embedding = embeddingModel.embed(randomUUID()).content(); - String id = embeddingStore.add(embedding); - assertThat(id).isNotNull(); - - Thread.sleep(2000); - - List> relevant = embeddingStore.findRelevant(embedding, 10); - assertThat(relevant).hasSize(1); - - EmbeddingMatch match = relevant.get(0); - assertThat(match.score()).isCloseTo(1, withPercentage(1)); - assertThat(match.embeddingId()).isEqualTo(id); - assertThat(match.embedding()).isEqualTo(embedding); - assertThat(match.embedded()).isNull(); + @Override + protected EmbeddingStore embeddingStore() { + return embeddingStore; } - @Test - void should_add_embedding_with_id() throws InterruptedException { - - String id = randomUUID(); - Embedding embedding = embeddingModel.embed(randomUUID()).content(); - embeddingStore.add(id, embedding); - - Thread.sleep(2000); - - List> relevant = embeddingStore.findRelevant(embedding, 10); - assertThat(relevant).hasSize(1); - - EmbeddingMatch match = relevant.get(0); - assertThat(match.score()).isCloseTo(1, withPercentage(1)); - assertThat(match.embeddingId()).isEqualTo(id); - assertThat(match.embedding()).isEqualTo(embedding); - assertThat(match.embedded()).isNull(); + @Override + protected EmbeddingModel embeddingModel() { + return embeddingModel; } - @Test - void should_add_embedding_with_segment() throws InterruptedException { - - TextSegment segment = TextSegment.from(randomUUID()); - Embedding embedding = embeddingModel.embed(segment.text()).content(); - - String id = embeddingStore.add(embedding, segment); - assertThat(id).isNotNull(); - - Thread.sleep(2000); - - List> relevant = embeddingStore.findRelevant(embedding, 10); - assertThat(relevant).hasSize(1); - - EmbeddingMatch match = relevant.get(0); - assertThat(match.score()).isCloseTo(1, withPercentage(1)); - assertThat(match.embeddingId()).isEqualTo(id); - assertThat(match.embedding()).isEqualTo(embedding); - assertThat(match.embedded()).isEqualTo(segment); + @Override + protected void ensureStoreIsEmpty() { + // TODO fix } - @Test - void should_add_embedding_with_segment_with_metadata() throws InterruptedException { - - TextSegment segment = TextSegment.from(randomUUID(), Metadata.from("test-key", "test-value")); - Embedding embedding = embeddingModel.embed(segment.text()).content(); - - String id = embeddingStore.add(embedding, segment); - assertThat(id).isNotNull(); - - Thread.sleep(2000); - - List> relevant = embeddingStore.findRelevant(embedding, 10); - assertThat(relevant).hasSize(1); - - EmbeddingMatch match = relevant.get(0); - assertThat(match.score()).isCloseTo(1, withPercentage(1)); - assertThat(match.embeddingId()).isEqualTo(id); - assertThat(match.embedding()).isEqualTo(embedding); - assertThat(match.embedded()).isEqualTo(segment); - } - - @Test - void should_add_multiple_embeddings() throws InterruptedException { - - Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content(); - Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content(); - - List ids = embeddingStore.addAll(asList(firstEmbedding, secondEmbedding)); - assertThat(ids).hasSize(2); - - Thread.sleep(2000); - - List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); - assertThat(relevant).hasSize(2); - - EmbeddingMatch firstMatch = relevant.get(0); - assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); - assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0)); - assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding); - assertThat(firstMatch.embedded()).isNull(); - - EmbeddingMatch secondMatch = relevant.get(1); - assertThat(secondMatch.score()).isBetween(0d, 1d); - assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); - assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding); - assertThat(secondMatch.embedded()).isNull(); - } - - @Test - void should_add_multiple_embeddings_with_segments() throws InterruptedException { - - TextSegment firstSegment = TextSegment.from(randomUUID()); - Embedding firstEmbedding = embeddingModel.embed(firstSegment.text()).content(); - TextSegment secondSegment = TextSegment.from(randomUUID()); - Embedding secondEmbedding = embeddingModel.embed(secondSegment.text()).content(); - - List ids = embeddingStore.addAll( - asList(firstEmbedding, secondEmbedding), - asList(firstSegment, secondSegment) - ); - assertThat(ids).hasSize(2); - - Thread.sleep(2000); - - List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); - assertThat(relevant).hasSize(2); - - EmbeddingMatch firstMatch = relevant.get(0); - assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); - assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0)); - assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding); - assertThat(firstMatch.embedded()).isEqualTo(firstSegment); - - EmbeddingMatch secondMatch = relevant.get(1); - assertThat(secondMatch.score()).isBetween(0d, 1d); - assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); - assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding); - assertThat(secondMatch.embedded()).isEqualTo(secondSegment); - } - - @Test - void should_find_with_min_score() throws InterruptedException { - - String firstId = randomUUID(); - Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content(); - embeddingStore.add(firstId, firstEmbedding); - - String secondId = randomUUID(); - Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content(); - embeddingStore.add(secondId, secondEmbedding); - - Thread.sleep(2000); - - List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); - assertThat(relevant).hasSize(2); - EmbeddingMatch firstMatch = relevant.get(0); - assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); - assertThat(firstMatch.embeddingId()).isEqualTo(firstId); - EmbeddingMatch secondMatch = relevant.get(1); - assertThat(secondMatch.score()).isBetween(0d, 1d); - assertThat(secondMatch.embeddingId()).isEqualTo(secondId); - - List> relevant2 = embeddingStore.findRelevant( - firstEmbedding, - 10, - secondMatch.score() - 0.01 - ); - assertThat(relevant2).hasSize(2); - assertThat(relevant2.get(0).embeddingId()).isEqualTo(firstId); - assertThat(relevant2.get(1).embeddingId()).isEqualTo(secondId); - - List> relevant3 = embeddingStore.findRelevant( - firstEmbedding, - 10, - secondMatch.score() - ); - assertThat(relevant3).hasSize(2); - assertThat(relevant3.get(0).embeddingId()).isEqualTo(firstId); - assertThat(relevant3.get(1).embeddingId()).isEqualTo(secondId); - - List> relevant4 = embeddingStore.findRelevant( - firstEmbedding, - 10, - secondMatch.score() + 0.01 - ); - assertThat(relevant4).hasSize(1); - assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId); - } - - @Test - void should_return_correct_score() throws InterruptedException { - - Embedding embedding = embeddingModel.embed("hello").content(); - - String id = embeddingStore.add(embedding); - assertThat(id).isNotNull(); - - Embedding referenceEmbedding = embeddingModel.embed("hi").content(); - - Thread.sleep(2000); - - List> relevant = embeddingStore.findRelevant(referenceEmbedding, 1); - assertThat(relevant).hasSize(1); - - EmbeddingMatch match = relevant.get(0); - assertThat(match.score()).isCloseTo( - RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)), - withPercentage(1) - ); - } - - public static String randomUUID() { - return UUID.randomUUID().toString(); + @Override + @SneakyThrows + protected void awaitUntilPersisted() { + Thread.sleep(1000); } } diff --git a/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreLocalTest.java b/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreLocalTest.java index 4faec4900..662f8f065 100644 --- a/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreLocalTest.java +++ b/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreLocalTest.java @@ -1,30 +1,21 @@ package dev.langchain4j.store.embedding.opensearch; -import dev.langchain4j.data.document.Metadata; -import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; -import dev.langchain4j.store.embedding.CosineSimilarity; -import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingStore; -import dev.langchain4j.store.embedding.RelevanceScore; +import dev.langchain4j.store.embedding.EmbeddingStoreIT; +import lombok.SneakyThrows; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; import org.opensearch.testcontainers.OpensearchContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.utility.DockerImageName; -import java.util.List; - import static dev.langchain4j.internal.Utils.randomUUID; -import static java.util.Arrays.asList; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.data.Percentage.withPercentage; @Disabled("Needs OpenSearch running locally") -class OpenSearchEmbeddingStoreLocalTest { +class OpenSearchEmbeddingStoreLocalTest extends EmbeddingStoreIT { /** * To run the tests locally, you don't need to have OpenSearch up-and-running. This implementation @@ -33,228 +24,39 @@ class OpenSearchEmbeddingStoreLocalTest { */ @Container - private static final OpensearchContainer opensearch = + static OpensearchContainer opensearch = new OpensearchContainer(DockerImageName.parse("opensearchproject/opensearch:2.10.0")); - private final EmbeddingStore embeddingStore = OpenSearchEmbeddingStore.builder() + EmbeddingStore embeddingStore = OpenSearchEmbeddingStore.builder() .serverUrl(opensearch.getHttpHostAddress()) .indexName(randomUUID()) .build(); - private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); @BeforeAll static void startOpenSearch() { opensearch.start(); } - @Test - void should_add_embedding() throws InterruptedException { - - Embedding embedding = embeddingModel.embed(randomUUID()).content(); - String id = embeddingStore.add(embedding); - assertThat(id).isNotNull(); - - Thread.sleep(2000); - - List> relevant = embeddingStore.findRelevant(embedding, 10); - assertThat(relevant).hasSize(1); - - EmbeddingMatch match = relevant.get(0); - assertThat(match.score()).isCloseTo(1, withPercentage(1)); - assertThat(match.embeddingId()).isEqualTo(id); - assertThat(match.embedding()).isEqualTo(embedding); - assertThat(match.embedded()).isNull(); + @Override + protected EmbeddingStore embeddingStore() { + return embeddingStore; } - @Test - void should_add_embedding_with_id() throws InterruptedException { - - String id = randomUUID(); - Embedding embedding = embeddingModel.embed(randomUUID()).content(); - embeddingStore.add(id, embedding); - - Thread.sleep(2000); - - List> relevant = embeddingStore.findRelevant(embedding, 10); - assertThat(relevant).hasSize(1); - - EmbeddingMatch match = relevant.get(0); - assertThat(match.score()).isCloseTo(1, withPercentage(1)); - assertThat(match.embeddingId()).isEqualTo(id); - assertThat(match.embedding()).isEqualTo(embedding); - assertThat(match.embedded()).isNull(); + @Override + protected EmbeddingModel embeddingModel() { + return embeddingModel; } - @Test - void should_add_embedding_with_segment() throws InterruptedException { - - TextSegment segment = TextSegment.from(randomUUID()); - Embedding embedding = embeddingModel.embed(segment.text()).content(); - - String id = embeddingStore.add(embedding, segment); - assertThat(id).isNotNull(); - - Thread.sleep(2000); - - List> relevant = embeddingStore.findRelevant(embedding, 10); - assertThat(relevant).hasSize(1); - - EmbeddingMatch match = relevant.get(0); - assertThat(match.score()).isCloseTo(1, withPercentage(1)); - assertThat(match.embeddingId()).isEqualTo(id); - assertThat(match.embedding()).isEqualTo(embedding); - assertThat(match.embedded()).isEqualTo(segment); + @Override + protected void ensureStoreIsEmpty() { + // TODO fix } - @Test - void should_add_embedding_with_segment_with_metadata() throws InterruptedException { - - TextSegment segment = TextSegment.from(randomUUID(), Metadata.from("test-key", "test-value")); - Embedding embedding = embeddingModel.embed(segment.text()).content(); - - String id = embeddingStore.add(embedding, segment); - assertThat(id).isNotNull(); - - Thread.sleep(2000); - - List> relevant = embeddingStore.findRelevant(embedding, 10); - assertThat(relevant).hasSize(1); - - EmbeddingMatch match = relevant.get(0); - assertThat(match.score()).isCloseTo(1, withPercentage(1)); - assertThat(match.embeddingId()).isEqualTo(id); - assertThat(match.embedding()).isEqualTo(embedding); - assertThat(match.embedded()).isEqualTo(segment); - } - - @Test - void should_add_multiple_embeddings() throws InterruptedException { - - Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content(); - Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content(); - - List ids = embeddingStore.addAll(asList(firstEmbedding, secondEmbedding)); - assertThat(ids).hasSize(2); - - Thread.sleep(2000); - - List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); - assertThat(relevant).hasSize(2); - - EmbeddingMatch firstMatch = relevant.get(0); - assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); - assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0)); - assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding); - assertThat(firstMatch.embedded()).isNull(); - - EmbeddingMatch secondMatch = relevant.get(1); - assertThat(secondMatch.score()).isBetween(0d, 1d); - assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); - assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding); - assertThat(secondMatch.embedded()).isNull(); - } - - @Test - void should_add_multiple_embeddings_with_segments() throws InterruptedException { - - TextSegment firstSegment = TextSegment.from(randomUUID()); - Embedding firstEmbedding = embeddingModel.embed(firstSegment.text()).content(); - TextSegment secondSegment = TextSegment.from(randomUUID()); - Embedding secondEmbedding = embeddingModel.embed(secondSegment.text()).content(); - - List ids = embeddingStore.addAll( - asList(firstEmbedding, secondEmbedding), - asList(firstSegment, secondSegment) - ); - assertThat(ids).hasSize(2); - - Thread.sleep(2000); - - List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); - assertThat(relevant).hasSize(2); - - EmbeddingMatch firstMatch = relevant.get(0); - assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); - assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0)); - assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding); - assertThat(firstMatch.embedded()).isEqualTo(firstSegment); - - EmbeddingMatch secondMatch = relevant.get(1); - assertThat(secondMatch.score()).isBetween(0d, 1d); - assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); - assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding); - assertThat(secondMatch.embedded()).isEqualTo(secondSegment); - } - - @Test - void should_find_with_min_score() throws InterruptedException { - - String firstId = randomUUID(); - Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content(); - embeddingStore.add(firstId, firstEmbedding); - - String secondId = randomUUID(); - Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content(); - embeddingStore.add(secondId, secondEmbedding); - - Thread.sleep(2000); - - List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); - assertThat(relevant).hasSize(2); - EmbeddingMatch firstMatch = relevant.get(0); - assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); - assertThat(firstMatch.embeddingId()).isEqualTo(firstId); - EmbeddingMatch secondMatch = relevant.get(1); - assertThat(secondMatch.score()).isBetween(0d, 1d); - assertThat(secondMatch.embeddingId()).isEqualTo(secondId); - - List> relevant2 = embeddingStore.findRelevant( - firstEmbedding, - 10, - secondMatch.score() - 0.01 - ); - assertThat(relevant2).hasSize(2); - assertThat(relevant2.get(0).embeddingId()).isEqualTo(firstId); - assertThat(relevant2.get(1).embeddingId()).isEqualTo(secondId); - - List> relevant3 = embeddingStore.findRelevant( - firstEmbedding, - 10, - secondMatch.score() - ); - assertThat(relevant3).hasSize(2); - assertThat(relevant3.get(0).embeddingId()).isEqualTo(firstId); - assertThat(relevant3.get(1).embeddingId()).isEqualTo(secondId); - - List> relevant4 = embeddingStore.findRelevant( - firstEmbedding, - 10, - secondMatch.score() + 0.01 - ); - assertThat(relevant4).hasSize(1); - assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId); - } - - @Test - void should_return_correct_score() throws InterruptedException { - - Embedding embedding = embeddingModel.embed("hello").content(); - - String id = embeddingStore.add(embedding); - assertThat(id).isNotNull(); - - Embedding referenceEmbedding = embeddingModel.embed("hi").content(); - - Thread.sleep(2000); - - List> relevant = embeddingStore.findRelevant(referenceEmbedding, 1); - assertThat(relevant).hasSize(1); - - EmbeddingMatch match = relevant.get(0); - assertThat(match.score()).isCloseTo( - RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)), - withPercentage(1) - ); + @Override + @SneakyThrows + protected void awaitUntilPersisted() { + Thread.sleep(1000); } }