diff --git a/langchain4j-opensearch/pom.xml b/langchain4j-opensearch/pom.xml
index 8dfd228bc..a106bece5 100644
--- a/langchain4j-opensearch/pom.xml
+++ b/langchain4j-opensearch/pom.xml
@@ -69,6 +69,14 @@
slf4j-api
+
+ dev.langchain4j
+ langchain4j-core
+ tests
+ test-jar
+ test
+
+
org.junit.jupiter
junit-jupiter
diff --git a/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreAWSTest.java b/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreAWSTest.java
index 8e709b1a2..d70970563 100644
--- a/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreAWSTest.java
+++ b/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreAWSTest.java
@@ -1,28 +1,19 @@
package dev.langchain4j.store.embedding.opensearch;
-import dev.langchain4j.data.document.Metadata;
-import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
-import dev.langchain4j.store.embedding.CosineSimilarity;
-import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
-import dev.langchain4j.store.embedding.RelevanceScore;
+import dev.langchain4j.store.embedding.EmbeddingStoreIT;
+import lombok.SneakyThrows;
import org.junit.jupiter.api.Disabled;
-import org.junit.jupiter.api.Test;
import org.opensearch.client.transport.aws.AwsSdk2TransportOptions;
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
-import java.util.List;
-import java.util.UUID;
-
-import static java.util.Arrays.asList;
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.assertj.core.data.Percentage.withPercentage;
+import static dev.langchain4j.internal.Utils.randomUUID;
@Disabled("Needs OpenSearch running with AWS")
-public class OpenSearchEmbeddingStoreAWSTest {
+public class OpenSearchEmbeddingStoreAWSTest extends EmbeddingStoreIT {
/**
* To run the tests locally, you have to provide an Amazon OpenSearch domain. The code uses
@@ -30,234 +21,36 @@ public class OpenSearchEmbeddingStoreAWSTest {
* your credentials locally, see https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html.
*/
- private final String domainEndpoint = "your-generated-domain-endpoint-with-no-https.us-east-1.es.amazonaws.com";
- private final ProfileCredentialsProvider credentials = ProfileCredentialsProvider.create("default");
-
- private final AwsSdk2TransportOptions transportOptions = AwsSdk2TransportOptions.builder()
- .setCredentials(credentials)
- .build();
-
- private final EmbeddingStore embeddingStore = OpenSearchEmbeddingStore.builder()
- .serverUrl(domainEndpoint)
+ EmbeddingStore embeddingStore = OpenSearchEmbeddingStore.builder()
+ .serverUrl("your-generated-domain-endpoint-with-no-https.us-east-1.es.amazonaws.com")
.serviceName("es")
.region("us-east-1")
- .options(transportOptions)
+ .options(AwsSdk2TransportOptions.builder()
+ .setCredentials(ProfileCredentialsProvider.create("default"))
+ .build())
.indexName(randomUUID())
.build();
- private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
+ EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
- @Test
- void should_add_embedding() throws InterruptedException {
-
- Embedding embedding = embeddingModel.embed(randomUUID()).content();
- String id = embeddingStore.add(embedding);
- assertThat(id).isNotNull();
-
- Thread.sleep(2000);
-
- List> relevant = embeddingStore.findRelevant(embedding, 10);
- assertThat(relevant).hasSize(1);
-
- EmbeddingMatch match = relevant.get(0);
- assertThat(match.score()).isCloseTo(1, withPercentage(1));
- assertThat(match.embeddingId()).isEqualTo(id);
- assertThat(match.embedding()).isEqualTo(embedding);
- assertThat(match.embedded()).isNull();
+ @Override
+ protected EmbeddingStore embeddingStore() {
+ return embeddingStore;
}
- @Test
- void should_add_embedding_with_id() throws InterruptedException {
-
- String id = randomUUID();
- Embedding embedding = embeddingModel.embed(randomUUID()).content();
- embeddingStore.add(id, embedding);
-
- Thread.sleep(2000);
-
- List> relevant = embeddingStore.findRelevant(embedding, 10);
- assertThat(relevant).hasSize(1);
-
- EmbeddingMatch match = relevant.get(0);
- assertThat(match.score()).isCloseTo(1, withPercentage(1));
- assertThat(match.embeddingId()).isEqualTo(id);
- assertThat(match.embedding()).isEqualTo(embedding);
- assertThat(match.embedded()).isNull();
+ @Override
+ protected EmbeddingModel embeddingModel() {
+ return embeddingModel;
}
- @Test
- void should_add_embedding_with_segment() throws InterruptedException {
-
- TextSegment segment = TextSegment.from(randomUUID());
- Embedding embedding = embeddingModel.embed(segment.text()).content();
-
- String id = embeddingStore.add(embedding, segment);
- assertThat(id).isNotNull();
-
- Thread.sleep(2000);
-
- List> relevant = embeddingStore.findRelevant(embedding, 10);
- assertThat(relevant).hasSize(1);
-
- EmbeddingMatch match = relevant.get(0);
- assertThat(match.score()).isCloseTo(1, withPercentage(1));
- assertThat(match.embeddingId()).isEqualTo(id);
- assertThat(match.embedding()).isEqualTo(embedding);
- assertThat(match.embedded()).isEqualTo(segment);
+ @Override
+ protected void ensureStoreIsEmpty() {
+ // TODO fix
}
- @Test
- void should_add_embedding_with_segment_with_metadata() throws InterruptedException {
-
- TextSegment segment = TextSegment.from(randomUUID(), Metadata.from("test-key", "test-value"));
- Embedding embedding = embeddingModel.embed(segment.text()).content();
-
- String id = embeddingStore.add(embedding, segment);
- assertThat(id).isNotNull();
-
- Thread.sleep(2000);
-
- List> relevant = embeddingStore.findRelevant(embedding, 10);
- assertThat(relevant).hasSize(1);
-
- EmbeddingMatch match = relevant.get(0);
- assertThat(match.score()).isCloseTo(1, withPercentage(1));
- assertThat(match.embeddingId()).isEqualTo(id);
- assertThat(match.embedding()).isEqualTo(embedding);
- assertThat(match.embedded()).isEqualTo(segment);
- }
-
- @Test
- void should_add_multiple_embeddings() throws InterruptedException {
-
- Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content();
- Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content();
-
- List ids = embeddingStore.addAll(asList(firstEmbedding, secondEmbedding));
- assertThat(ids).hasSize(2);
-
- Thread.sleep(2000);
-
- List> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
- assertThat(relevant).hasSize(2);
-
- EmbeddingMatch firstMatch = relevant.get(0);
- assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
- assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
- assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
- assertThat(firstMatch.embedded()).isNull();
-
- EmbeddingMatch secondMatch = relevant.get(1);
- assertThat(secondMatch.score()).isBetween(0d, 1d);
- assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
- assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
- assertThat(secondMatch.embedded()).isNull();
- }
-
- @Test
- void should_add_multiple_embeddings_with_segments() throws InterruptedException {
-
- TextSegment firstSegment = TextSegment.from(randomUUID());
- Embedding firstEmbedding = embeddingModel.embed(firstSegment.text()).content();
- TextSegment secondSegment = TextSegment.from(randomUUID());
- Embedding secondEmbedding = embeddingModel.embed(secondSegment.text()).content();
-
- List ids = embeddingStore.addAll(
- asList(firstEmbedding, secondEmbedding),
- asList(firstSegment, secondSegment)
- );
- assertThat(ids).hasSize(2);
-
- Thread.sleep(2000);
-
- List> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
- assertThat(relevant).hasSize(2);
-
- EmbeddingMatch firstMatch = relevant.get(0);
- assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
- assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
- assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
- assertThat(firstMatch.embedded()).isEqualTo(firstSegment);
-
- EmbeddingMatch secondMatch = relevant.get(1);
- assertThat(secondMatch.score()).isBetween(0d, 1d);
- assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
- assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
- assertThat(secondMatch.embedded()).isEqualTo(secondSegment);
- }
-
- @Test
- void should_find_with_min_score() throws InterruptedException {
-
- String firstId = randomUUID();
- Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content();
- embeddingStore.add(firstId, firstEmbedding);
-
- String secondId = randomUUID();
- Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content();
- embeddingStore.add(secondId, secondEmbedding);
-
- Thread.sleep(2000);
-
- List> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
- assertThat(relevant).hasSize(2);
- EmbeddingMatch firstMatch = relevant.get(0);
- assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
- assertThat(firstMatch.embeddingId()).isEqualTo(firstId);
- EmbeddingMatch secondMatch = relevant.get(1);
- assertThat(secondMatch.score()).isBetween(0d, 1d);
- assertThat(secondMatch.embeddingId()).isEqualTo(secondId);
-
- List> relevant2 = embeddingStore.findRelevant(
- firstEmbedding,
- 10,
- secondMatch.score() - 0.01
- );
- assertThat(relevant2).hasSize(2);
- assertThat(relevant2.get(0).embeddingId()).isEqualTo(firstId);
- assertThat(relevant2.get(1).embeddingId()).isEqualTo(secondId);
-
- List> relevant3 = embeddingStore.findRelevant(
- firstEmbedding,
- 10,
- secondMatch.score()
- );
- assertThat(relevant3).hasSize(2);
- assertThat(relevant3.get(0).embeddingId()).isEqualTo(firstId);
- assertThat(relevant3.get(1).embeddingId()).isEqualTo(secondId);
-
- List> relevant4 = embeddingStore.findRelevant(
- firstEmbedding,
- 10,
- secondMatch.score() + 0.01
- );
- assertThat(relevant4).hasSize(1);
- assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId);
- }
-
- @Test
- void should_return_correct_score() throws InterruptedException {
-
- Embedding embedding = embeddingModel.embed("hello").content();
-
- String id = embeddingStore.add(embedding);
- assertThat(id).isNotNull();
-
- Embedding referenceEmbedding = embeddingModel.embed("hi").content();
-
- Thread.sleep(2000);
-
- List> relevant = embeddingStore.findRelevant(referenceEmbedding, 1);
- assertThat(relevant).hasSize(1);
-
- EmbeddingMatch match = relevant.get(0);
- assertThat(match.score()).isCloseTo(
- RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)),
- withPercentage(1)
- );
- }
-
- public static String randomUUID() {
- return UUID.randomUUID().toString();
+ @Override
+ @SneakyThrows
+ protected void awaitUntilPersisted() {
+ Thread.sleep(1000);
}
}
diff --git a/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreLocalTest.java b/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreLocalTest.java
index 4faec4900..662f8f065 100644
--- a/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreLocalTest.java
+++ b/langchain4j-opensearch/src/test/java/dev/langchain4j/store/embedding/opensearch/OpenSearchEmbeddingStoreLocalTest.java
@@ -1,30 +1,21 @@
package dev.langchain4j.store.embedding.opensearch;
-import dev.langchain4j.data.document.Metadata;
-import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
-import dev.langchain4j.store.embedding.CosineSimilarity;
-import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
-import dev.langchain4j.store.embedding.RelevanceScore;
+import dev.langchain4j.store.embedding.EmbeddingStoreIT;
+import lombok.SneakyThrows;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
-import org.junit.jupiter.api.Test;
import org.opensearch.testcontainers.OpensearchContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.utility.DockerImageName;
-import java.util.List;
-
import static dev.langchain4j.internal.Utils.randomUUID;
-import static java.util.Arrays.asList;
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.assertj.core.data.Percentage.withPercentage;
@Disabled("Needs OpenSearch running locally")
-class OpenSearchEmbeddingStoreLocalTest {
+class OpenSearchEmbeddingStoreLocalTest extends EmbeddingStoreIT {
/**
* To run the tests locally, you don't need to have OpenSearch up-and-running. This implementation
@@ -33,228 +24,39 @@ class OpenSearchEmbeddingStoreLocalTest {
*/
@Container
- private static final OpensearchContainer opensearch =
+ static OpensearchContainer opensearch =
new OpensearchContainer(DockerImageName.parse("opensearchproject/opensearch:2.10.0"));
- private final EmbeddingStore embeddingStore = OpenSearchEmbeddingStore.builder()
+ EmbeddingStore embeddingStore = OpenSearchEmbeddingStore.builder()
.serverUrl(opensearch.getHttpHostAddress())
.indexName(randomUUID())
.build();
- private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
+ EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
@BeforeAll
static void startOpenSearch() {
opensearch.start();
}
- @Test
- void should_add_embedding() throws InterruptedException {
-
- Embedding embedding = embeddingModel.embed(randomUUID()).content();
- String id = embeddingStore.add(embedding);
- assertThat(id).isNotNull();
-
- Thread.sleep(2000);
-
- List> relevant = embeddingStore.findRelevant(embedding, 10);
- assertThat(relevant).hasSize(1);
-
- EmbeddingMatch match = relevant.get(0);
- assertThat(match.score()).isCloseTo(1, withPercentage(1));
- assertThat(match.embeddingId()).isEqualTo(id);
- assertThat(match.embedding()).isEqualTo(embedding);
- assertThat(match.embedded()).isNull();
+ @Override
+ protected EmbeddingStore embeddingStore() {
+ return embeddingStore;
}
- @Test
- void should_add_embedding_with_id() throws InterruptedException {
-
- String id = randomUUID();
- Embedding embedding = embeddingModel.embed(randomUUID()).content();
- embeddingStore.add(id, embedding);
-
- Thread.sleep(2000);
-
- List> relevant = embeddingStore.findRelevant(embedding, 10);
- assertThat(relevant).hasSize(1);
-
- EmbeddingMatch match = relevant.get(0);
- assertThat(match.score()).isCloseTo(1, withPercentage(1));
- assertThat(match.embeddingId()).isEqualTo(id);
- assertThat(match.embedding()).isEqualTo(embedding);
- assertThat(match.embedded()).isNull();
+ @Override
+ protected EmbeddingModel embeddingModel() {
+ return embeddingModel;
}
- @Test
- void should_add_embedding_with_segment() throws InterruptedException {
-
- TextSegment segment = TextSegment.from(randomUUID());
- Embedding embedding = embeddingModel.embed(segment.text()).content();
-
- String id = embeddingStore.add(embedding, segment);
- assertThat(id).isNotNull();
-
- Thread.sleep(2000);
-
- List> relevant = embeddingStore.findRelevant(embedding, 10);
- assertThat(relevant).hasSize(1);
-
- EmbeddingMatch match = relevant.get(0);
- assertThat(match.score()).isCloseTo(1, withPercentage(1));
- assertThat(match.embeddingId()).isEqualTo(id);
- assertThat(match.embedding()).isEqualTo(embedding);
- assertThat(match.embedded()).isEqualTo(segment);
+ @Override
+ protected void ensureStoreIsEmpty() {
+ // TODO fix
}
- @Test
- void should_add_embedding_with_segment_with_metadata() throws InterruptedException {
-
- TextSegment segment = TextSegment.from(randomUUID(), Metadata.from("test-key", "test-value"));
- Embedding embedding = embeddingModel.embed(segment.text()).content();
-
- String id = embeddingStore.add(embedding, segment);
- assertThat(id).isNotNull();
-
- Thread.sleep(2000);
-
- List> relevant = embeddingStore.findRelevant(embedding, 10);
- assertThat(relevant).hasSize(1);
-
- EmbeddingMatch match = relevant.get(0);
- assertThat(match.score()).isCloseTo(1, withPercentage(1));
- assertThat(match.embeddingId()).isEqualTo(id);
- assertThat(match.embedding()).isEqualTo(embedding);
- assertThat(match.embedded()).isEqualTo(segment);
- }
-
- @Test
- void should_add_multiple_embeddings() throws InterruptedException {
-
- Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content();
- Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content();
-
- List ids = embeddingStore.addAll(asList(firstEmbedding, secondEmbedding));
- assertThat(ids).hasSize(2);
-
- Thread.sleep(2000);
-
- List> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
- assertThat(relevant).hasSize(2);
-
- EmbeddingMatch firstMatch = relevant.get(0);
- assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
- assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
- assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
- assertThat(firstMatch.embedded()).isNull();
-
- EmbeddingMatch secondMatch = relevant.get(1);
- assertThat(secondMatch.score()).isBetween(0d, 1d);
- assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
- assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
- assertThat(secondMatch.embedded()).isNull();
- }
-
- @Test
- void should_add_multiple_embeddings_with_segments() throws InterruptedException {
-
- TextSegment firstSegment = TextSegment.from(randomUUID());
- Embedding firstEmbedding = embeddingModel.embed(firstSegment.text()).content();
- TextSegment secondSegment = TextSegment.from(randomUUID());
- Embedding secondEmbedding = embeddingModel.embed(secondSegment.text()).content();
-
- List ids = embeddingStore.addAll(
- asList(firstEmbedding, secondEmbedding),
- asList(firstSegment, secondSegment)
- );
- assertThat(ids).hasSize(2);
-
- Thread.sleep(2000);
-
- List> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
- assertThat(relevant).hasSize(2);
-
- EmbeddingMatch firstMatch = relevant.get(0);
- assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
- assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
- assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
- assertThat(firstMatch.embedded()).isEqualTo(firstSegment);
-
- EmbeddingMatch secondMatch = relevant.get(1);
- assertThat(secondMatch.score()).isBetween(0d, 1d);
- assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
- assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
- assertThat(secondMatch.embedded()).isEqualTo(secondSegment);
- }
-
- @Test
- void should_find_with_min_score() throws InterruptedException {
-
- String firstId = randomUUID();
- Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content();
- embeddingStore.add(firstId, firstEmbedding);
-
- String secondId = randomUUID();
- Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content();
- embeddingStore.add(secondId, secondEmbedding);
-
- Thread.sleep(2000);
-
- List> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
- assertThat(relevant).hasSize(2);
- EmbeddingMatch firstMatch = relevant.get(0);
- assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
- assertThat(firstMatch.embeddingId()).isEqualTo(firstId);
- EmbeddingMatch secondMatch = relevant.get(1);
- assertThat(secondMatch.score()).isBetween(0d, 1d);
- assertThat(secondMatch.embeddingId()).isEqualTo(secondId);
-
- List> relevant2 = embeddingStore.findRelevant(
- firstEmbedding,
- 10,
- secondMatch.score() - 0.01
- );
- assertThat(relevant2).hasSize(2);
- assertThat(relevant2.get(0).embeddingId()).isEqualTo(firstId);
- assertThat(relevant2.get(1).embeddingId()).isEqualTo(secondId);
-
- List> relevant3 = embeddingStore.findRelevant(
- firstEmbedding,
- 10,
- secondMatch.score()
- );
- assertThat(relevant3).hasSize(2);
- assertThat(relevant3.get(0).embeddingId()).isEqualTo(firstId);
- assertThat(relevant3.get(1).embeddingId()).isEqualTo(secondId);
-
- List> relevant4 = embeddingStore.findRelevant(
- firstEmbedding,
- 10,
- secondMatch.score() + 0.01
- );
- assertThat(relevant4).hasSize(1);
- assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId);
- }
-
- @Test
- void should_return_correct_score() throws InterruptedException {
-
- Embedding embedding = embeddingModel.embed("hello").content();
-
- String id = embeddingStore.add(embedding);
- assertThat(id).isNotNull();
-
- Embedding referenceEmbedding = embeddingModel.embed("hi").content();
-
- Thread.sleep(2000);
-
- List> relevant = embeddingStore.findRelevant(referenceEmbedding, 1);
- assertThat(relevant).hasSize(1);
-
- EmbeddingMatch match = relevant.get(0);
- assertThat(match.score()).isCloseTo(
- RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)),
- withPercentage(1)
- );
+ @Override
+ @SneakyThrows
+ protected void awaitUntilPersisted() {
+ Thread.sleep(1000);
}
}