Fixed relevance score calculation (#164)

This commit is contained in:
LangChain4j 2023-09-07 19:19:20 +02:00 committed by GitHub
parent f2bb6f992e
commit b804d03ca8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 131 additions and 56 deletions

View File

@ -0,0 +1,62 @@
package dev.langchain4j.store.embedding;
import dev.langchain4j.data.embedding.Embedding;
import static dev.langchain4j.internal.Exceptions.illegalArgument;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
public class CosineSimilarity {
/**
* Calculates cosine similarity between two vectors.
* <p>
* Cosine similarity measures the cosine of the angle between two vectors, indicating their directional similarity.
* It produces a value in the range:
* <p>
* -1 indicates vectors are diametrically opposed (opposite directions).
* <p>
* 0 indicates vectors are orthogonal (no directional similarity).
* <p>
* 1 indicates vectors are pointing in the same direction (but not necessarily of the same magnitude).
* <p>
* Not to be confused with cosine distance ([0..2]), which quantifies how different two vectors are.
*
* @param embeddingA first embedding vector
* @param embeddingB second embedding vector
* @return cosine similarity in the range [-1..1]
*/
public static double between(Embedding embeddingA, Embedding embeddingB) {
ensureNotNull(embeddingA, "embeddingA");
ensureNotNull(embeddingB, "embeddingB");
float[] vectorA = embeddingA.vector();
float[] vectorB = embeddingB.vector();
if (vectorA.length != vectorB.length) {
throw illegalArgument("Length of vector a (%s) must be equal to the length of vector b (%s)",
vectorA.length, vectorB.length);
}
double dotProduct = 0.0;
double normA = 0.0;
double normB = 0.0;
for (int i = 0; i < vectorA.length; i++) {
dotProduct += vectorA[i] * vectorB[i];
normA += vectorA[i] * vectorA[i];
normB += vectorB[i] * vectorB[i];
}
return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
}
/**
* Converts relevance score into cosine similarity.
*
* @param relevanceScore Relevance score in the range [0..1] where 0 is not relevant and 1 is relevant.
* @return Cosine similarity in the range [-1..1] where -1 is not relevant and 1 is relevant.
*/
public static double fromRelevanceScore(double relevanceScore) {
return relevanceScore * 2 - 1;
}
}

View File

@ -3,14 +3,12 @@ package dev.langchain4j.store.embedding;
public class RelevanceScore {
/**
* Calculates the relevance score between two vectors using cosine similarity.
* Converts cosine similarity into relevance score.
*
* @param a first vector
* @param b second vector
* @return score in the range [0, 1], where 0 indicates no relevance and 1 indicates full relevance
* @param cosineSimilarity Cosine similarity in the range [-1..1] where -1 is not relevant and 1 is relevant.
* @return Relevance score in the range [0..1] where 0 is not relevant and 1 is relevant.
*/
public static double cosine(float[] a, float[] b) {
double cosineSimilarity = Similarity.cosine(a, b);
return 1 - (1 - cosineSimilarity) / 2;
public static double fromCosineSimilarity(double cosineSimilarity) {
return (cosineSimilarity + 1) / 2;
}
}

View File

@ -1,43 +0,0 @@
package dev.langchain4j.store.embedding;
import static dev.langchain4j.internal.Exceptions.illegalArgument;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
public class Similarity {
/**
* Calculates cosine similarity between two vectors.
* <p>
* Cosine similarity measures the cosine of the angle between two vectors, indicating their directional similarity.
* It produces a value in the range:
* - -1 indicates vectors are diametrically opposed (opposite directions).
* - 0 indicates vectors are orthogonal (no directional similarity).
* - 1 indicates vectors are pointing in the same direction (but not necessarily of the same magnitude).
* <p>
* Not to be confused with cosine distance ([0..2]), which quantifies how different two vectors are.
*
* @param a first vector
* @param b second vector
* @return cosine similarity in the range [-1..1]
*/
public static double cosine(float[] a, float[] b) {
ensureNotNull(a, "a");
ensureNotNull(b, "b");
if (a.length != b.length) {
throw illegalArgument("Length of vector a (%s) must be equal to the length of vector b (%s)",
a.length, b.length);
}
double dotProduct = 0.0;
double normA = 0.0;
double normB = 0.0;
for (int i = 0; i < a.length; i++) {
dotProduct += a[i] * b[i];
normA += a[i] * a[i];
normB += b[i] * b[i];
}
return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
}
}

View File

@ -0,0 +1,26 @@
package dev.langchain4j.store.embedding;
import dev.langchain4j.data.embedding.Embedding;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.data.Percentage.withPercentage;
class CosineSimilarityTest {
@Test
void should_calculate_cosine_similarity() {
Embedding embeddingA = Embedding.from(new float[]{1, 1, 1});
Embedding embeddingB = Embedding.from(new float[]{-1, -1, -1});
assertThat(CosineSimilarity.between(embeddingA, embeddingA)).isCloseTo(1, withPercentage(1));
assertThat(CosineSimilarity.between(embeddingA, embeddingB)).isCloseTo(-1, withPercentage(1));
}
@Test
void should_convert_relevance_score_into_cosine_similarity() {
assertThat(CosineSimilarity.fromRelevanceScore(0)).isEqualTo(-1);
assertThat(CosineSimilarity.fromRelevanceScore(0.5)).isEqualTo(0);
assertThat(CosineSimilarity.fromRelevanceScore(1)).isEqualTo(1);
}
}

View File

@ -0,0 +1,15 @@
package dev.langchain4j.store.embedding;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
class RelevanceScoreTest {
@Test
void should_convert_cosine_similarity_into_relevance_score() {
assertThat(RelevanceScore.fromCosineSimilarity(-1)).isEqualTo(0);
assertThat(RelevanceScore.fromCosineSimilarity(0)).isEqualTo(0.5);
assertThat(RelevanceScore.fromCosineSimilarity(1)).isEqualTo(1);
}
}

View File

@ -4,6 +4,7 @@ import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
@ -176,9 +177,10 @@ public class PineconeEmbeddingStoreImpl implements EmbeddingStore<TextSegment> {
.get(METADATA_TEXT_SEGMENT);
Embedding embedding = Embedding.from(vector.getValuesList());
double cosineSimilarity = CosineSimilarity.between(embedding, referenceEmbedding);
return new EmbeddingMatch<>(
RelevanceScore.cosine(embedding.vector(), referenceEmbedding.vector()),
RelevanceScore.fromCosineSimilarity(cosineSimilarity),
vector.getId(),
embedding,
textSegmentValue == null ? null : TextSegment.from(textSegmentValue.getStringValue())

View File

@ -2,11 +2,18 @@ package dev.langchain4j.classification;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.RelevanceScore;
import java.util.*;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static dev.langchain4j.internal.ValidationUtils.*;
import static dev.langchain4j.internal.ValidationUtils.ensureBetween;
import static dev.langchain4j.internal.ValidationUtils.ensureGreaterThanZero;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static java.util.Comparator.comparingDouble;
import static java.util.stream.Collectors.toList;
@ -111,7 +118,8 @@ public class EmbeddingModelTextClassifier<E extends Enum<E>> implements TextClas
double meanScore = 0;
double maxScore = 0;
for (Embedding exampleEmbedding : exampleEmbeddings) {
double score = RelevanceScore.cosine(textEmbedding.vector(), exampleEmbedding.vector());
double cosineSimilarity = CosineSimilarity.between(textEmbedding, exampleEmbedding);
double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
meanScore += score;
maxScore = Math.max(score, maxScore);
}

View File

@ -4,6 +4,7 @@ import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
@ -13,7 +14,12 @@ import java.lang.reflect.Type;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.PriorityQueue;
import static dev.langchain4j.internal.Utils.randomUUID;
import static java.nio.file.StandardOpenOption.CREATE;
@ -116,7 +122,8 @@ public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded
PriorityQueue<EmbeddingMatch<Embedded>> matches = new PriorityQueue<>(comparator);
for (Entry<Embedded> entry : entries) {
double score = RelevanceScore.cosine(entry.embedding.vector(), referenceEmbedding.vector());
double cosineSimilarity = CosineSimilarity.between(entry.embedding, referenceEmbedding);
double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
if (score >= minScore) {
matches.add(new EmbeddingMatch<>(score, entry.id, entry.embedding, entry.embedded));
if (matches.size() > maxResults) {