Chroma: added filtering by score
This commit is contained in:
parent
c1d0b8df32
commit
94f62cd538
|
@ -45,6 +45,21 @@
|
|||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
<licenses>
|
||||
<license>
|
||||
<name>Apache-2.0</name>
|
||||
<url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>
|
||||
<distribution>repo</distribution>
|
||||
<comments>A business-friendly OSS license</comments>
|
||||
</license>
|
||||
</licenses>
|
||||
|
||||
</project>
|
|
@ -11,6 +11,7 @@ import java.util.ArrayList;
|
|||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.getOrDefault;
|
||||
import static dev.langchain4j.internal.Utils.randomUUID;
|
||||
import static java.time.Duration.ofSeconds;
|
||||
import static java.util.Collections.singletonList;
|
||||
|
@ -33,13 +34,11 @@ public class ChromaEmbeddingStore implements EmbeddingStore<TextSegment> {
|
|||
* @param timeout The timeout duration for the Chroma client. If not specified, 5 seconds will be used.
|
||||
*/
|
||||
public ChromaEmbeddingStore(String baseUrl, String collectionName, Duration timeout) {
|
||||
collectionName = collectionName == null ? "default" : collectionName;
|
||||
timeout = timeout == null ? ofSeconds(5) : timeout;
|
||||
collectionName = getOrDefault(collectionName, "default");
|
||||
|
||||
this.chromaClient = new ChromaClient(baseUrl, timeout);
|
||||
this.chromaClient = new ChromaClient(baseUrl, getOrDefault(timeout, ofSeconds(5)));
|
||||
|
||||
Collection collection = chromaClient.collection(collectionName);
|
||||
|
||||
if (collection == null) {
|
||||
Collection createdCollection = chromaClient.createCollection(new CreateCollectionRequest(collectionName));
|
||||
collectionId = createdCollection.id();
|
||||
|
@ -165,7 +164,11 @@ public class ChromaEmbeddingStore implements EmbeddingStore<TextSegment> {
|
|||
|
||||
QueryResponse queryResponse = chromaClient.queryCollection(collectionId, queryRequest);
|
||||
|
||||
return toEmbeddingMatches(queryResponse);
|
||||
List<EmbeddingMatch<TextSegment>> matches = toEmbeddingMatches(queryResponse);
|
||||
|
||||
return matches.stream()
|
||||
.filter(match -> match.score() >= minScore)
|
||||
.collect(toList());
|
||||
}
|
||||
|
||||
private static List<EmbeddingMatch<TextSegment>> toEmbeddingMatches(QueryResponse queryResponse) {
|
||||
|
@ -197,6 +200,6 @@ public class ChromaEmbeddingStore implements EmbeddingStore<TextSegment> {
|
|||
private static TextSegment toTextSegment(QueryResponse queryResponse, int i) {
|
||||
String text = queryResponse.documents().get(0).get(i);
|
||||
Map<String, String> metadata = queryResponse.metadatas().get(0).get(i);
|
||||
return text == null ? null : TextSegment.from(text, metadata == null ? null : new Metadata(metadata));
|
||||
return text == null ? null : TextSegment.from(text, metadata == null ? new Metadata() : new Metadata(metadata));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,8 +3,11 @@ package dev.langchain4j.store.embedding.chroma;
|
|||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.CosineSimilarity;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.RelevanceScore;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
@ -12,35 +15,220 @@ import org.junit.jupiter.api.Test;
|
|||
import java.util.List;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.randomUUID;
|
||||
import static java.util.Arrays.asList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.data.Percentage.withPercentage;
|
||||
|
||||
@Disabled("needs Chroma running locally")
|
||||
class ChromaEmbeddingStoreTest {
|
||||
|
||||
/**
|
||||
* First ensure you have Chroma running locally. If not, then:
|
||||
* - Execute "docker pull ghcr.io/chroma-core/chroma:0.4.6"
|
||||
* - Execute "docker run -d -p 8000:8000 ghcr.io/chroma-core/chroma:0.4.6"
|
||||
* - Wait until Chroma is ready to serve (may take a few minutes)
|
||||
*/
|
||||
|
||||
private final EmbeddingStore<TextSegment> embeddingStore = ChromaEmbeddingStore.builder()
|
||||
.baseUrl("http://localhost:8000")
|
||||
.collectionName(randomUUID())
|
||||
.build();
|
||||
|
||||
private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
|
||||
@Test
|
||||
@Disabled("To run this test, you must have a local Chroma instance")
|
||||
public void testAddEmbeddingAndFindRelevant() {
|
||||
void should_add_embedding() {
|
||||
|
||||
ChromaEmbeddingStore chromaEmbeddingStore = ChromaEmbeddingStore.builder()
|
||||
.baseUrl("http://localhost:8000")
|
||||
.collectionName(randomUUID())
|
||||
.build();
|
||||
Embedding embedding = embeddingModel.embed(randomUUID()).content();
|
||||
|
||||
Embedding embedding = Embedding.from(new float[]{9.9F, 4.5F, 3.5F, 1.3F, 1.7F, 5.7F, 6.4F, 5.5F, 8.2F, 9.3F, 1.5F});
|
||||
TextSegment textSegment = TextSegment.from("Text", Metadata.from("Key", "Value"));
|
||||
String id = chromaEmbeddingStore.add(embedding, textSegment);
|
||||
assertThat(id).isNotBlank();
|
||||
String id = embeddingStore.add(embedding);
|
||||
assertThat(id).isNotNull();
|
||||
|
||||
Embedding refereceEmbedding = Embedding.from(new float[]{8.7F, 4.5F, 3.4F, 1.2F, 5.5F, 5.6F, 6.4F, 5.5F, 8.1F, 9.1F, 1.1F});
|
||||
List<EmbeddingMatch<TextSegment>> embeddingMatches = chromaEmbeddingStore.findRelevant(refereceEmbedding, 10);
|
||||
assertThat(embeddingMatches).hasSize(1);
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
|
||||
assertThat(relevant).hasSize(1);
|
||||
|
||||
EmbeddingMatch<TextSegment> embeddingMatch = embeddingMatches.get(0);
|
||||
assertThat(embeddingMatch.score()).isCloseTo(
|
||||
RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, refereceEmbedding)),
|
||||
withPercentage(1));
|
||||
assertThat(embeddingMatch.embeddingId()).isEqualTo(id);
|
||||
assertThat(embeddingMatch.embedded()).isEqualTo(textSegment);
|
||||
assertThat(embeddingMatch.embedding()).isEqualTo(embedding);
|
||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(match.embeddingId()).isEqualTo(id);
|
||||
assertThat(match.embedding()).isEqualTo(embedding);
|
||||
assertThat(match.embedded()).isNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_add_embedding_with_id() {
|
||||
|
||||
String id = randomUUID();
|
||||
Embedding embedding = embeddingModel.embed(randomUUID()).content();
|
||||
|
||||
embeddingStore.add(id, embedding);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
|
||||
assertThat(relevant).hasSize(1);
|
||||
|
||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(match.embeddingId()).isEqualTo(id);
|
||||
assertThat(match.embedding()).isEqualTo(embedding);
|
||||
assertThat(match.embedded()).isNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_add_embedding_with_segment() {
|
||||
|
||||
TextSegment segment = TextSegment.from(randomUUID());
|
||||
Embedding embedding = embeddingModel.embed(segment.text()).content();
|
||||
|
||||
String id = embeddingStore.add(embedding, segment);
|
||||
assertThat(id).isNotNull();
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
|
||||
assertThat(relevant).hasSize(1);
|
||||
|
||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(match.embeddingId()).isEqualTo(id);
|
||||
assertThat(match.embedding()).isEqualTo(embedding);
|
||||
assertThat(match.embedded()).isEqualTo(segment);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_add_embedding_with_segment_with_metadata() {
|
||||
|
||||
TextSegment segment = TextSegment.from(randomUUID(), Metadata.from("test-key", "test-value"));
|
||||
Embedding embedding = embeddingModel.embed(segment.text()).content();
|
||||
|
||||
String id = embeddingStore.add(embedding, segment);
|
||||
assertThat(id).isNotNull();
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
|
||||
assertThat(relevant).hasSize(1);
|
||||
|
||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(match.embeddingId()).isEqualTo(id);
|
||||
assertThat(match.embedding()).isEqualTo(embedding);
|
||||
assertThat(match.embedded()).isEqualTo(segment);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_add_multiple_embeddings() {
|
||||
|
||||
Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content();
|
||||
Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content();
|
||||
|
||||
List<String> ids = embeddingStore.addAll(asList(firstEmbedding, secondEmbedding));
|
||||
assertThat(ids).hasSize(2);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
|
||||
assertThat(relevant).hasSize(2);
|
||||
|
||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
|
||||
assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
|
||||
assertThat(firstMatch.embedded()).isNull();
|
||||
|
||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
||||
assertThat(secondMatch.score()).isBetween(0d, 1d);
|
||||
assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
|
||||
assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
|
||||
assertThat(secondMatch.embedded()).isNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_add_multiple_embeddings_with_segments() {
|
||||
|
||||
TextSegment firstSegment = TextSegment.from(randomUUID());
|
||||
Embedding firstEmbedding = embeddingModel.embed(firstSegment.text()).content();
|
||||
TextSegment secondSegment = TextSegment.from(randomUUID());
|
||||
Embedding secondEmbedding = embeddingModel.embed(secondSegment.text()).content();
|
||||
|
||||
List<String> ids = embeddingStore.addAll(
|
||||
asList(firstEmbedding, secondEmbedding),
|
||||
asList(firstSegment, secondSegment)
|
||||
);
|
||||
assertThat(ids).hasSize(2);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
|
||||
assertThat(relevant).hasSize(2);
|
||||
|
||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
|
||||
assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
|
||||
assertThat(firstMatch.embedded()).isEqualTo(firstSegment);
|
||||
|
||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
||||
assertThat(secondMatch.score()).isBetween(0d, 1d);
|
||||
assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
|
||||
assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
|
||||
assertThat(secondMatch.embedded()).isEqualTo(secondSegment);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_find_with_min_score() {
|
||||
|
||||
String firstId = randomUUID();
|
||||
Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content();
|
||||
embeddingStore.add(firstId, firstEmbedding);
|
||||
|
||||
String secondId = randomUUID();
|
||||
Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content();
|
||||
embeddingStore.add(secondId, secondEmbedding);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
|
||||
assertThat(relevant).hasSize(2);
|
||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(firstMatch.embeddingId()).isEqualTo(firstId);
|
||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
||||
assertThat(secondMatch.score()).isBetween(0d, 1d);
|
||||
assertThat(secondMatch.embeddingId()).isEqualTo(secondId);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant2 = embeddingStore.findRelevant(
|
||||
firstEmbedding,
|
||||
10,
|
||||
secondMatch.score() - 0.01
|
||||
);
|
||||
assertThat(relevant2).hasSize(2);
|
||||
assertThat(relevant2.get(0).embeddingId()).isEqualTo(firstId);
|
||||
assertThat(relevant2.get(1).embeddingId()).isEqualTo(secondId);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant3 = embeddingStore.findRelevant(
|
||||
firstEmbedding,
|
||||
10,
|
||||
secondMatch.score()
|
||||
);
|
||||
assertThat(relevant3).hasSize(2);
|
||||
assertThat(relevant3.get(0).embeddingId()).isEqualTo(firstId);
|
||||
assertThat(relevant3.get(1).embeddingId()).isEqualTo(secondId);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant4 = embeddingStore.findRelevant(
|
||||
firstEmbedding,
|
||||
10,
|
||||
secondMatch.score() + 0.01
|
||||
);
|
||||
assertThat(relevant4).hasSize(1);
|
||||
assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_return_correct_score() {
|
||||
|
||||
Embedding embedding = embeddingModel.embed("hello").content();
|
||||
|
||||
String id = embeddingStore.add(embedding);
|
||||
assertThat(id).isNotNull();
|
||||
|
||||
Embedding referenceEmbedding = embeddingModel.embed("hi").content();
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(referenceEmbedding, 1);
|
||||
assertThat(relevant).hasSize(1);
|
||||
|
||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
||||
assertThat(match.score()).isCloseTo(
|
||||
RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)),
|
||||
withPercentage(1)
|
||||
);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue