From 94f62cd538e4c6cc46ee90e7d69fc62289f40fe4 Mon Sep 17 00:00:00 2001 From: deep-learning-dynamo Date: Thu, 28 Sep 2023 18:15:57 +0200 Subject: [PATCH] Chroma: added filtering by score --- langchain4j-chroma/pom.xml | 15 ++ .../chroma/ChromaEmbeddingStore.java | 15 +- .../chroma/ChromaEmbeddingStoreTest.java | 228 ++++++++++++++++-- 3 files changed, 232 insertions(+), 26 deletions(-) diff --git a/langchain4j-chroma/pom.xml b/langchain4j-chroma/pom.xml index 011ee4a2f..22237bc7a 100644 --- a/langchain4j-chroma/pom.xml +++ b/langchain4j-chroma/pom.xml @@ -45,6 +45,21 @@ test + + dev.langchain4j + langchain4j-embeddings-all-minilm-l6-v2-q + test + + + + + Apache-2.0 + https://www.apache.org/licenses/LICENSE-2.0.txt + repo + A business-friendly OSS license + + + \ No newline at end of file diff --git a/langchain4j-chroma/src/main/java/dev/langchain4j/store/embedding/chroma/ChromaEmbeddingStore.java b/langchain4j-chroma/src/main/java/dev/langchain4j/store/embedding/chroma/ChromaEmbeddingStore.java index 617b9f7bb..ba1494597 100644 --- a/langchain4j-chroma/src/main/java/dev/langchain4j/store/embedding/chroma/ChromaEmbeddingStore.java +++ b/langchain4j-chroma/src/main/java/dev/langchain4j/store/embedding/chroma/ChromaEmbeddingStore.java @@ -11,6 +11,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import static dev.langchain4j.internal.Utils.getOrDefault; import static dev.langchain4j.internal.Utils.randomUUID; import static java.time.Duration.ofSeconds; import static java.util.Collections.singletonList; @@ -33,13 +34,11 @@ public class ChromaEmbeddingStore implements EmbeddingStore { * @param timeout The timeout duration for the Chroma client. If not specified, 5 seconds will be used. */ public ChromaEmbeddingStore(String baseUrl, String collectionName, Duration timeout) { - collectionName = collectionName == null ? "default" : collectionName; - timeout = timeout == null ? ofSeconds(5) : timeout; + collectionName = getOrDefault(collectionName, "default"); - this.chromaClient = new ChromaClient(baseUrl, timeout); + this.chromaClient = new ChromaClient(baseUrl, getOrDefault(timeout, ofSeconds(5))); Collection collection = chromaClient.collection(collectionName); - if (collection == null) { Collection createdCollection = chromaClient.createCollection(new CreateCollectionRequest(collectionName)); collectionId = createdCollection.id(); @@ -165,7 +164,11 @@ public class ChromaEmbeddingStore implements EmbeddingStore { QueryResponse queryResponse = chromaClient.queryCollection(collectionId, queryRequest); - return toEmbeddingMatches(queryResponse); + List> matches = toEmbeddingMatches(queryResponse); + + return matches.stream() + .filter(match -> match.score() >= minScore) + .collect(toList()); } private static List> toEmbeddingMatches(QueryResponse queryResponse) { @@ -197,6 +200,6 @@ public class ChromaEmbeddingStore implements EmbeddingStore { private static TextSegment toTextSegment(QueryResponse queryResponse, int i) { String text = queryResponse.documents().get(0).get(i); Map metadata = queryResponse.metadatas().get(0).get(i); - return text == null ? null : TextSegment.from(text, metadata == null ? null : new Metadata(metadata)); + return text == null ? null : TextSegment.from(text, metadata == null ? new Metadata() : new Metadata(metadata)); } } diff --git a/langchain4j-chroma/src/test/java/dev/langchain4j/store/embedding/chroma/ChromaEmbeddingStoreTest.java b/langchain4j-chroma/src/test/java/dev/langchain4j/store/embedding/chroma/ChromaEmbeddingStoreTest.java index fcb295162..25f709c72 100644 --- a/langchain4j-chroma/src/test/java/dev/langchain4j/store/embedding/chroma/ChromaEmbeddingStoreTest.java +++ b/langchain4j-chroma/src/test/java/dev/langchain4j/store/embedding/chroma/ChromaEmbeddingStoreTest.java @@ -3,8 +3,11 @@ package dev.langchain4j.store.embedding.chroma; import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel; +import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.store.embedding.CosineSimilarity; import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.RelevanceScore; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; @@ -12,35 +15,220 @@ import org.junit.jupiter.api.Test; import java.util.List; import static dev.langchain4j.internal.Utils.randomUUID; +import static java.util.Arrays.asList; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.data.Percentage.withPercentage; +@Disabled("needs Chroma running locally") class ChromaEmbeddingStoreTest { + /** + * First ensure you have Chroma running locally. If not, then: + * - Execute "docker pull ghcr.io/chroma-core/chroma:0.4.6" + * - Execute "docker run -d -p 8000:8000 ghcr.io/chroma-core/chroma:0.4.6" + * - Wait until Chroma is ready to serve (may take a few minutes) + */ + + private final EmbeddingStore embeddingStore = ChromaEmbeddingStore.builder() + .baseUrl("http://localhost:8000") + .collectionName(randomUUID()) + .build(); + + private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + @Test - @Disabled("To run this test, you must have a local Chroma instance") - public void testAddEmbeddingAndFindRelevant() { + void should_add_embedding() { - ChromaEmbeddingStore chromaEmbeddingStore = ChromaEmbeddingStore.builder() - .baseUrl("http://localhost:8000") - .collectionName(randomUUID()) - .build(); + Embedding embedding = embeddingModel.embed(randomUUID()).content(); - Embedding embedding = Embedding.from(new float[]{9.9F, 4.5F, 3.5F, 1.3F, 1.7F, 5.7F, 6.4F, 5.5F, 8.2F, 9.3F, 1.5F}); - TextSegment textSegment = TextSegment.from("Text", Metadata.from("Key", "Value")); - String id = chromaEmbeddingStore.add(embedding, textSegment); - assertThat(id).isNotBlank(); + String id = embeddingStore.add(embedding); + assertThat(id).isNotNull(); - Embedding refereceEmbedding = Embedding.from(new float[]{8.7F, 4.5F, 3.4F, 1.2F, 5.5F, 5.6F, 6.4F, 5.5F, 8.1F, 9.1F, 1.1F}); - List> embeddingMatches = chromaEmbeddingStore.findRelevant(refereceEmbedding, 10); - assertThat(embeddingMatches).hasSize(1); + List> relevant = embeddingStore.findRelevant(embedding, 10); + assertThat(relevant).hasSize(1); - EmbeddingMatch embeddingMatch = embeddingMatches.get(0); - assertThat(embeddingMatch.score()).isCloseTo( - RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, refereceEmbedding)), - withPercentage(1)); - assertThat(embeddingMatch.embeddingId()).isEqualTo(id); - assertThat(embeddingMatch.embedded()).isEqualTo(textSegment); - assertThat(embeddingMatch.embedding()).isEqualTo(embedding); + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo(1, withPercentage(1)); + assertThat(match.embeddingId()).isEqualTo(id); + assertThat(match.embedding()).isEqualTo(embedding); + assertThat(match.embedded()).isNull(); + } + + @Test + void should_add_embedding_with_id() { + + String id = randomUUID(); + Embedding embedding = embeddingModel.embed(randomUUID()).content(); + + embeddingStore.add(id, embedding); + + List> relevant = embeddingStore.findRelevant(embedding, 10); + assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo(1, withPercentage(1)); + assertThat(match.embeddingId()).isEqualTo(id); + assertThat(match.embedding()).isEqualTo(embedding); + assertThat(match.embedded()).isNull(); + } + + @Test + void should_add_embedding_with_segment() { + + TextSegment segment = TextSegment.from(randomUUID()); + Embedding embedding = embeddingModel.embed(segment.text()).content(); + + String id = embeddingStore.add(embedding, segment); + assertThat(id).isNotNull(); + + List> relevant = embeddingStore.findRelevant(embedding, 10); + assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo(1, withPercentage(1)); + assertThat(match.embeddingId()).isEqualTo(id); + assertThat(match.embedding()).isEqualTo(embedding); + assertThat(match.embedded()).isEqualTo(segment); + } + + @Test + void should_add_embedding_with_segment_with_metadata() { + + TextSegment segment = TextSegment.from(randomUUID(), Metadata.from("test-key", "test-value")); + Embedding embedding = embeddingModel.embed(segment.text()).content(); + + String id = embeddingStore.add(embedding, segment); + assertThat(id).isNotNull(); + + List> relevant = embeddingStore.findRelevant(embedding, 10); + assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo(1, withPercentage(1)); + assertThat(match.embeddingId()).isEqualTo(id); + assertThat(match.embedding()).isEqualTo(embedding); + assertThat(match.embedded()).isEqualTo(segment); + } + + @Test + void should_add_multiple_embeddings() { + + Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content(); + Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content(); + + List ids = embeddingStore.addAll(asList(firstEmbedding, secondEmbedding)); + assertThat(ids).hasSize(2); + + List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); + assertThat(relevant).hasSize(2); + + EmbeddingMatch firstMatch = relevant.get(0); + assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); + assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0)); + assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding); + assertThat(firstMatch.embedded()).isNull(); + + EmbeddingMatch secondMatch = relevant.get(1); + assertThat(secondMatch.score()).isBetween(0d, 1d); + assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); + assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding); + assertThat(secondMatch.embedded()).isNull(); + } + + @Test + void should_add_multiple_embeddings_with_segments() { + + TextSegment firstSegment = TextSegment.from(randomUUID()); + Embedding firstEmbedding = embeddingModel.embed(firstSegment.text()).content(); + TextSegment secondSegment = TextSegment.from(randomUUID()); + Embedding secondEmbedding = embeddingModel.embed(secondSegment.text()).content(); + + List ids = embeddingStore.addAll( + asList(firstEmbedding, secondEmbedding), + asList(firstSegment, secondSegment) + ); + assertThat(ids).hasSize(2); + + List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); + assertThat(relevant).hasSize(2); + + EmbeddingMatch firstMatch = relevant.get(0); + assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); + assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0)); + assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding); + assertThat(firstMatch.embedded()).isEqualTo(firstSegment); + + EmbeddingMatch secondMatch = relevant.get(1); + assertThat(secondMatch.score()).isBetween(0d, 1d); + assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1)); + assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding); + assertThat(secondMatch.embedded()).isEqualTo(secondSegment); + } + + @Test + void should_find_with_min_score() { + + String firstId = randomUUID(); + Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content(); + embeddingStore.add(firstId, firstEmbedding); + + String secondId = randomUUID(); + Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content(); + embeddingStore.add(secondId, secondEmbedding); + + List> relevant = embeddingStore.findRelevant(firstEmbedding, 10); + assertThat(relevant).hasSize(2); + EmbeddingMatch firstMatch = relevant.get(0); + assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1)); + assertThat(firstMatch.embeddingId()).isEqualTo(firstId); + EmbeddingMatch secondMatch = relevant.get(1); + assertThat(secondMatch.score()).isBetween(0d, 1d); + assertThat(secondMatch.embeddingId()).isEqualTo(secondId); + + List> relevant2 = embeddingStore.findRelevant( + firstEmbedding, + 10, + secondMatch.score() - 0.01 + ); + assertThat(relevant2).hasSize(2); + assertThat(relevant2.get(0).embeddingId()).isEqualTo(firstId); + assertThat(relevant2.get(1).embeddingId()).isEqualTo(secondId); + + List> relevant3 = embeddingStore.findRelevant( + firstEmbedding, + 10, + secondMatch.score() + ); + assertThat(relevant3).hasSize(2); + assertThat(relevant3.get(0).embeddingId()).isEqualTo(firstId); + assertThat(relevant3.get(1).embeddingId()).isEqualTo(secondId); + + List> relevant4 = embeddingStore.findRelevant( + firstEmbedding, + 10, + secondMatch.score() + 0.01 + ); + assertThat(relevant4).hasSize(1); + assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId); + } + + @Test + void should_return_correct_score() { + + Embedding embedding = embeddingModel.embed("hello").content(); + + String id = embeddingStore.add(embedding); + assertThat(id).isNotNull(); + + Embedding referenceEmbedding = embeddingModel.embed("hi").content(); + + List> relevant = embeddingStore.findRelevant(referenceEmbedding, 1); + assertThat(relevant).hasSize(1); + + EmbeddingMatch match = relevant.get(0); + assertThat(match.score()).isCloseTo( + RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)), + withPercentage(1) + ); } } \ No newline at end of file