From 903b3fb1ac4c392fa084ea542e3e29117a170106 Mon Sep 17 00:00:00 2001 From: deep-learning-dynamo Date: Sat, 18 Nov 2023 21:28:32 +0100 Subject: [PATCH] reducing duplication of *EmbeddingStoreIT --- langchain4j-weaviate/pom.xml | 8 + .../weaviate/WeaviateEmbeddingStoreIT.java | 195 ++---------------- 2 files changed, 20 insertions(+), 183 deletions(-) diff --git a/langchain4j-weaviate/pom.xml b/langchain4j-weaviate/pom.xml index 474222e5a..3fcf2f510 100644 --- a/langchain4j-weaviate/pom.xml +++ b/langchain4j-weaviate/pom.xml @@ -44,6 +44,14 @@ 1.16.0 + + dev.langchain4j + langchain4j-core + tests + test-jar + test + + org.junit.jupiter junit-jupiter-engine diff --git a/langchain4j-weaviate/src/test/java/dev/langchain4j/store/embedding/weaviate/WeaviateEmbeddingStoreIT.java b/langchain4j-weaviate/src/test/java/dev/langchain4j/store/embedding/weaviate/WeaviateEmbeddingStoreIT.java index 271855c8a..be02e8245 100644 --- a/langchain4j-weaviate/src/test/java/dev/langchain4j/store/embedding/weaviate/WeaviateEmbeddingStoreIT.java +++ b/langchain4j-weaviate/src/test/java/dev/langchain4j/store/embedding/weaviate/WeaviateEmbeddingStoreIT.java @@ -1,209 +1,38 @@ package dev.langchain4j.store.embedding.weaviate; -import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel; -import dev.langchain4j.store.embedding.CosineSimilarity; -import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingStore; -import dev.langchain4j.store.embedding.RelevanceScore; -import org.junit.jupiter.api.Test; +import dev.langchain4j.store.embedding.EmbeddingStoreWithoutMetadataIT; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; -import java.util.List; - import static dev.langchain4j.internal.Utils.randomUUID; -import static java.util.Arrays.asList; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.data.Percentage.withPercentage; @EnabledIfEnvironmentVariable(named = "WEAVIATE_API_KEY", matches = ".+") -class WeaviateEmbeddingStoreIT { +class WeaviateEmbeddingStoreIT extends EmbeddingStoreWithoutMetadataIT { EmbeddingStore embeddingStore = WeaviateEmbeddingStore.builder() .apiKey(System.getenv("WEAVIATE_API_KEY")) .scheme("https") - .host("test3-bwsieg9y.weaviate.network") + .host("test-am8ocede.weaviate.network") .objectClass("Test" + randomUUID().replace("-", "")) .build(); EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); - @Test - void should_add_embedding() { - - Embedding embedding = embeddingModel.embed(randomUUID()).content(); - - String id = embeddingStore.add(embedding); - assertThat(id).isNotNull(); - - List> relevant = embeddingStore.findRelevant(embedding, 10); - assertThat(relevant).hasSize(1); - - EmbeddingMatch match = relevant.get(0); - assertThat(match.score()).isCloseTo(1, withPercentage(1)); - assertThat(match.embeddingId()).isEqualTo(id); - assertThat(match.embedding()).isEqualTo(embedding); - assertThat(match.embedded()).isNull(); + @Override + protected EmbeddingStore embeddingStore() { + return embeddingStore; } - @Test - void should_add_embedding_with_id() { - - String id = randomUUID(); - Embedding embedding = embeddingModel.embed(randomUUID()).content(); - - embeddingStore.add(id, embedding); - - List> relevant = embeddingStore.findRelevant(embedding, 10); - assertThat(relevant).hasSize(1); - - EmbeddingMatch match = relevant.get(0); - assertThat(match.score()).isCloseTo(1, withPercentage(1)); - assertThat(match.embeddingId()).isEqualTo(id); - assertThat(match.embedding()).isEqualTo(embedding); - assertThat(match.embedded()).isNull(); + @Override + protected EmbeddingModel embeddingModel() { + return embeddingModel; } - @Test - void should_add_embedding_with_segment() { - - TextSegment segment = TextSegment.from(randomUUID()); - Embedding embedding = embeddingModel.embed(segment.text()).content(); - - String id = embeddingStore.add(embedding, segment); - assertThat(id).isNotNull(); - - List> relevant = embeddingStore.findRelevant(embedding, 10); - assertThat(relevant).hasSize(1); - - EmbeddingMatch match = relevant.get(0); - assertThat(match.score()).isCloseTo(1, withPercentage(1)); - assertThat(match.embeddingId()).isEqualTo(id); - assertThat(match.embedding()).isEqualTo(embedding); - assertThat(match.embedded()).isEqualTo(segment); - } - - @Test - void should_add_multiple_embeddings() { - - Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content(); - Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content(); - - List ids = embeddingStore.addAll(asList(firstEmbedding, secondEmbedding)); - assertThat(ids).hasSize(2); - - List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); - assertThat(relevant).hasSize(2); - - EmbeddingMatch firstMatch = relevant.get(0); - assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); - assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0)); - assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding); - assertThat(firstMatch.embedded()).isNull(); - - EmbeddingMatch secondMatch = relevant.get(1); - assertThat(secondMatch.score()).isBetween(0d, 1d); - assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); - assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding); - assertThat(secondMatch.embedded()).isNull(); - } - - @Test - void should_add_multiple_embeddings_with_segments() { - - TextSegment firstSegment = TextSegment.from(randomUUID()); - Embedding firstEmbedding = embeddingModel.embed(firstSegment.text()).content(); - TextSegment secondSegment = TextSegment.from(randomUUID()); - Embedding secondEmbedding = embeddingModel.embed(secondSegment.text()).content(); - - List ids = embeddingStore.addAll( - asList(firstEmbedding, secondEmbedding), - asList(firstSegment, secondSegment) - ); - assertThat(ids).hasSize(2); - - List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); - assertThat(relevant).hasSize(2); - - EmbeddingMatch firstMatch = relevant.get(0); - assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); - assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0)); - assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding); - assertThat(firstMatch.embedded()).isEqualTo(firstSegment); - - EmbeddingMatch secondMatch = relevant.get(1); - assertThat(secondMatch.score()).isBetween(0d, 1d); - assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); - assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding); - assertThat(secondMatch.embedded()).isEqualTo(secondSegment); - } - - @Test - void should_find_with_min_score() { - - String firstId = randomUUID(); - Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content(); - embeddingStore.add(firstId, firstEmbedding); - - String secondId = randomUUID(); - Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content(); - embeddingStore.add(secondId, secondEmbedding); - - List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); - assertThat(relevant).hasSize(2); - EmbeddingMatch firstMatch = relevant.get(0); - assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); - assertThat(firstMatch.embeddingId()).isEqualTo(firstId); - EmbeddingMatch secondMatch = relevant.get(1); - assertThat(secondMatch.score()).isBetween(0d, 1d); - assertThat(secondMatch.embeddingId()).isEqualTo(secondId); - - List> relevant2 = embeddingStore.findRelevant( - firstEmbedding, - 10, - secondMatch.score() - 0.01 - ); - assertThat(relevant2).hasSize(2); - assertThat(relevant2.get(0).embeddingId()).isEqualTo(firstId); - assertThat(relevant2.get(1).embeddingId()).isEqualTo(secondId); - - List> relevant3 = embeddingStore.findRelevant( - firstEmbedding, - 10, - secondMatch.score() - ); - assertThat(relevant3).hasSize(2); - assertThat(relevant3.get(0).embeddingId()).isEqualTo(firstId); - assertThat(relevant3.get(1).embeddingId()).isEqualTo(secondId); - - List> relevant4 = embeddingStore.findRelevant( - firstEmbedding, - 10, - secondMatch.score() + 0.01 - ); - assertThat(relevant4).hasSize(1); - assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId); - } - - @Test - void should_return_correct_score() { - - Embedding embedding = embeddingModel.embed("hello").content(); - - String id = embeddingStore.add(embedding); - assertThat(id).isNotNull(); - - Embedding referenceEmbedding = embeddingModel.embed("hi").content(); - - List> relevant = embeddingStore.findRelevant(referenceEmbedding, 1); - assertThat(relevant).hasSize(1); - - EmbeddingMatch match = relevant.get(0); - assertThat(match.score()).isCloseTo( - RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)), - withPercentage(1) - ); + @Override + protected void ensureStoreIsEmpty() { + // TODO fix } } \ No newline at end of file