diff --git a/langchain4j-chroma/pom.xml b/langchain4j-chroma/pom.xml
index 011ee4a2f..22237bc7a 100644
--- a/langchain4j-chroma/pom.xml
+++ b/langchain4j-chroma/pom.xml
@@ -45,6 +45,21 @@
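             <scope>test</scope>
         </dependency>
 
+        <dependency>
+            <groupId>dev.langchain4j</groupId>
+            <artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
+            <scope>test</scope>
+        </dependency>
+
     </dependencies>
 
+    <licenses>
+        <license>
+            <name>Apache-2.0</name>
+            <url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>
+            <distribution>repo</distribution>
+            <comments>A business-friendly OSS license</comments>
+        </license>
+    </licenses>
+
 </project>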
\ No newline at end of file
diff --git a/langchain4j-chroma/src/main/java/dev/langchain4j/store/embedding/chroma/ChromaEmbeddingStore.java b/langchain4j-chroma/src/main/java/dev/langchain4j/store/embedding/chroma/ChromaEmbeddingStore.java
index 617b9f7bb..ba1494597 100644
--- a/langchain4j-chroma/src/main/java/dev/langchain4j/store/embedding/chroma/ChromaEmbeddingStore.java
+++ b/langchain4j-chroma/src/main/java/dev/langchain4j/store/embedding/chroma/ChromaEmbeddingStore.java
@@ -11,6 +11,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Map;
+import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.Utils.randomUUID;
import static java.time.Duration.ofSeconds;
import static java.util.Collections.singletonList;
@@ -33,13 +34,11 @@ public class ChromaEmbeddingStore implements EmbeddingStore {
* @param timeout The timeout duration for the Chroma client. If not specified, 5 seconds will be used.
*/
public ChromaEmbeddingStore(String baseUrl, String collectionName, Duration timeout) {
- collectionName = collectionName == null ? "default" : collectionName;
- timeout = timeout == null ? ofSeconds(5) : timeout;
+ collectionName = getOrDefault(collectionName, "default");
- this.chromaClient = new ChromaClient(baseUrl, timeout);
+ this.chromaClient = new ChromaClient(baseUrl, getOrDefault(timeout, ofSeconds(5)));
Collection collection = chromaClient.collection(collectionName);
-
if (collection == null) {
Collection createdCollection = chromaClient.createCollection(new CreateCollectionRequest(collectionName));
collectionId = createdCollection.id();
@@ -165,7 +164,11 @@ public class ChromaEmbeddingStore implements EmbeddingStore {
QueryResponse queryResponse = chromaClient.queryCollection(collectionId, queryRequest);
- return toEmbeddingMatches(queryResponse);
+ List<EmbeddingMatch<TextSegment>> matches = toEmbeddingMatches(queryResponse);
+
+ return matches.stream()
+ .filter(match -> match.score() >= minScore)
+ .collect(toList());
}
private static List> toEmbeddingMatches(QueryResponse queryResponse) {
@@ -197,6 +200,6 @@ public class ChromaEmbeddingStore implements EmbeddingStore {
private static TextSegment toTextSegment(QueryResponse queryResponse, int i) {
String text = queryResponse.documents().get(0).get(i);
Map<String, String> metadata = queryResponse.metadatas().get(0).get(i);
- return text == null ? null : TextSegment.from(text, metadata == null ? null : new Metadata(metadata));
+ return text == null ? null : TextSegment.from(text, metadata == null ? new Metadata() : new Metadata(metadata));
}
}
diff --git a/langchain4j-chroma/src/test/java/dev/langchain4j/store/embedding/chroma/ChromaEmbeddingStoreTest.java b/langchain4j-chroma/src/test/java/dev/langchain4j/store/embedding/chroma/ChromaEmbeddingStoreTest.java
index fcb295162..25f709c72 100644
--- a/langchain4j-chroma/src/test/java/dev/langchain4j/store/embedding/chroma/ChromaEmbeddingStoreTest.java
+++ b/langchain4j-chroma/src/test/java/dev/langchain4j/store/embedding/chroma/ChromaEmbeddingStoreTest.java
@@ -3,8 +3,11 @@ package dev.langchain4j.store.embedding.chroma;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
+import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
+import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.EmbeddingMatch;
+import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
@@ -12,35 +15,220 @@ import org.junit.jupiter.api.Test;
import java.util.List;
import static dev.langchain4j.internal.Utils.randomUUID;
+import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.data.Percentage.withPercentage;
+@Disabled("needs Chroma running locally")
class ChromaEmbeddingStoreTest {
+ /**
+ * First ensure you have Chroma running locally. If not, then:
+ * - Execute "docker pull ghcr.io/chroma-core/chroma:0.4.6"
+ * - Execute "docker run -d -p 8000:8000 ghcr.io/chroma-core/chroma:0.4.6"
+ * - Wait until Chroma is ready to serve (may take a few minutes)
+ */
+
+ private final EmbeddingStore<TextSegment> embeddingStore = ChromaEmbeddingStore.builder()
+ .baseUrl("http://localhost:8000")
+ .collectionName(randomUUID())
+ .build();
+
+ private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
+
@Test
- @Disabled("To run this test, you must have a local Chroma instance")
- public void testAddEmbeddingAndFindRelevant() {
+ void should_add_embedding() {
- ChromaEmbeddingStore chromaEmbeddingStore = ChromaEmbeddingStore.builder()
- .baseUrl("http://localhost:8000")
- .collectionName(randomUUID())
- .build();
+ Embedding embedding = embeddingModel.embed(randomUUID()).content();
- Embedding embedding = Embedding.from(new float[]{9.9F, 4.5F, 3.5F, 1.3F, 1.7F, 5.7F, 6.4F, 5.5F, 8.2F, 9.3F, 1.5F});
- TextSegment textSegment = TextSegment.from("Text", Metadata.from("Key", "Value"));
- String id = chromaEmbeddingStore.add(embedding, textSegment);
- assertThat(id).isNotBlank();
+ String id = embeddingStore.add(embedding);
+ assertThat(id).isNotNull();
- Embedding refereceEmbedding = Embedding.from(new float[]{8.7F, 4.5F, 3.4F, 1.2F, 5.5F, 5.6F, 6.4F, 5.5F, 8.1F, 9.1F, 1.1F});
- List<EmbeddingMatch<TextSegment>> embeddingMatches = chromaEmbeddingStore.findRelevant(refereceEmbedding, 10);
- assertThat(embeddingMatches).hasSize(1);
+ List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
+ assertThat(relevant).hasSize(1);
- EmbeddingMatch<TextSegment> embeddingMatch = embeddingMatches.get(0);
- assertThat(embeddingMatch.score()).isCloseTo(
- RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, refereceEmbedding)),
- withPercentage(1));
- assertThat(embeddingMatch.embeddingId()).isEqualTo(id);
- assertThat(embeddingMatch.embedded()).isEqualTo(textSegment);
- assertThat(embeddingMatch.embedding()).isEqualTo(embedding);
+ EmbeddingMatch<TextSegment> match = relevant.get(0);
+ assertThat(match.score()).isCloseTo(1, withPercentage(1));
+ assertThat(match.embeddingId()).isEqualTo(id);
+ assertThat(match.embedding()).isEqualTo(embedding);
+ assertThat(match.embedded()).isNull();
+ }
+
+ @Test
+ void should_add_embedding_with_id() {
+
+ String id = randomUUID();
+ Embedding embedding = embeddingModel.embed(randomUUID()).content();
+
+ embeddingStore.add(id, embedding);
+
+ List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
+ assertThat(relevant).hasSize(1);
+
+ EmbeddingMatch<TextSegment> match = relevant.get(0);
+ assertThat(match.score()).isCloseTo(1, withPercentage(1));
+ assertThat(match.embeddingId()).isEqualTo(id);
+ assertThat(match.embedding()).isEqualTo(embedding);
+ assertThat(match.embedded()).isNull();
+ }
+
+ @Test
+ void should_add_embedding_with_segment() {
+
+ TextSegment segment = TextSegment.from(randomUUID());
+ Embedding embedding = embeddingModel.embed(segment.text()).content();
+
+ String id = embeddingStore.add(embedding, segment);
+ assertThat(id).isNotNull();
+
+ List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
+ assertThat(relevant).hasSize(1);
+
+ EmbeddingMatch<TextSegment> match = relevant.get(0);
+ assertThat(match.score()).isCloseTo(1, withPercentage(1));
+ assertThat(match.embeddingId()).isEqualTo(id);
+ assertThat(match.embedding()).isEqualTo(embedding);
+ assertThat(match.embedded()).isEqualTo(segment);
+ }
+
+ @Test
+ void should_add_embedding_with_segment_with_metadata() {
+
+ TextSegment segment = TextSegment.from(randomUUID(), Metadata.from("test-key", "test-value"));
+ Embedding embedding = embeddingModel.embed(segment.text()).content();
+
+ String id = embeddingStore.add(embedding, segment);
+ assertThat(id).isNotNull();
+
+ List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
+ assertThat(relevant).hasSize(1);
+
+ EmbeddingMatch<TextSegment> match = relevant.get(0);
+ assertThat(match.score()).isCloseTo(1, withPercentage(1));
+ assertThat(match.embeddingId()).isEqualTo(id);
+ assertThat(match.embedding()).isEqualTo(embedding);
+ assertThat(match.embedded()).isEqualTo(segment);
+ }
+
+ @Test
+ void should_add_multiple_embeddings() {
+
+ Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content();
+ Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content();
+
+ List<String> ids = embeddingStore.addAll(asList(firstEmbedding, secondEmbedding));
+ assertThat(ids).hasSize(2);
+
+ List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
+ assertThat(relevant).hasSize(2);
+
+ EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
+ assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
+ assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
+ assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
+ assertThat(firstMatch.embedded()).isNull();
+
+ EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
+ assertThat(secondMatch.score()).isBetween(0d, 1d);
+ assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
+ assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
+ assertThat(secondMatch.embedded()).isNull();
+ }
+
+ @Test
+ void should_add_multiple_embeddings_with_segments() {
+
+ TextSegment firstSegment = TextSegment.from(randomUUID());
+ Embedding firstEmbedding = embeddingModel.embed(firstSegment.text()).content();
+ TextSegment secondSegment = TextSegment.from(randomUUID());
+ Embedding secondEmbedding = embeddingModel.embed(secondSegment.text()).content();
+
+ List<String> ids = embeddingStore.addAll(
+ asList(firstEmbedding, secondEmbedding),
+ asList(firstSegment, secondSegment)
+ );
+ assertThat(ids).hasSize(2);
+
+ List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
+ assertThat(relevant).hasSize(2);
+
+ EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
+ assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
+ assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
+ assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
+ assertThat(firstMatch.embedded()).isEqualTo(firstSegment);
+
+ EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
+ assertThat(secondMatch.score()).isBetween(0d, 1d);
+ assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
+ assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
+ assertThat(secondMatch.embedded()).isEqualTo(secondSegment);
+ }
+
+ @Test
+ void should_find_with_min_score() {
+
+ String firstId = randomUUID();
+ Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content();
+ embeddingStore.add(firstId, firstEmbedding);
+
+ String secondId = randomUUID();
+ Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content();
+ embeddingStore.add(secondId, secondEmbedding);
+
+ List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
+ assertThat(relevant).hasSize(2);
+ EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
+ assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
+ assertThat(firstMatch.embeddingId()).isEqualTo(firstId);
+ EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
+ assertThat(secondMatch.score()).isBetween(0d, 1d);
+ assertThat(secondMatch.embeddingId()).isEqualTo(secondId);
+
+ List<EmbeddingMatch<TextSegment>> relevant2 = embeddingStore.findRelevant(
+ firstEmbedding,
+ 10,
+ secondMatch.score() - 0.01
+ );
+ assertThat(relevant2).hasSize(2);
+ assertThat(relevant2.get(0).embeddingId()).isEqualTo(firstId);
+ assertThat(relevant2.get(1).embeddingId()).isEqualTo(secondId);
+
+ List<EmbeddingMatch<TextSegment>> relevant3 = embeddingStore.findRelevant(
+ firstEmbedding,
+ 10,
+ secondMatch.score()
+ );
+ assertThat(relevant3).hasSize(2);
+ assertThat(relevant3.get(0).embeddingId()).isEqualTo(firstId);
+ assertThat(relevant3.get(1).embeddingId()).isEqualTo(secondId);
+
+ List<EmbeddingMatch<TextSegment>> relevant4 = embeddingStore.findRelevant(
+ firstEmbedding,
+ 10,
+ secondMatch.score() + 0.01
+ );
+ assertThat(relevant4).hasSize(1);
+ assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId);
+ }
+
+ @Test
+ void should_return_correct_score() {
+
+ Embedding embedding = embeddingModel.embed("hello").content();
+
+ String id = embeddingStore.add(embedding);
+ assertThat(id).isNotNull();
+
+ Embedding referenceEmbedding = embeddingModel.embed("hi").content();
+
+ List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(referenceEmbedding, 1);
+ assertThat(relevant).hasSize(1);
+
+ EmbeddingMatch<TextSegment> match = relevant.get(0);
+ assertThat(match.score()).isCloseTo(
+ RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)),
+ withPercentage(1)
+ );
}
}
\ No newline at end of file