Fixed relevance score calculation (#164)
This commit is contained in:
parent
f2bb6f992e
commit
b804d03ca8
|
@ -0,0 +1,62 @@
|
|||
package dev.langchain4j.store.embedding;
|
||||
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
|
||||
import static dev.langchain4j.internal.Exceptions.illegalArgument;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
|
||||
public class CosineSimilarity {
|
||||
|
||||
/**
|
||||
* Calculates cosine similarity between two vectors.
|
||||
* <p>
|
||||
* Cosine similarity measures the cosine of the angle between two vectors, indicating their directional similarity.
|
||||
* It produces a value in the range:
|
||||
* <p>
|
||||
* -1 indicates vectors are diametrically opposed (opposite directions).
|
||||
* <p>
|
||||
* 0 indicates vectors are orthogonal (no directional similarity).
|
||||
* <p>
|
||||
* 1 indicates vectors are pointing in the same direction (but not necessarily of the same magnitude).
|
||||
* <p>
|
||||
* Not to be confused with cosine distance ([0..2]), which quantifies how different two vectors are.
|
||||
*
|
||||
* @param embeddingA first embedding vector
|
||||
* @param embeddingB second embedding vector
|
||||
* @return cosine similarity in the range [-1..1]
|
||||
*/
|
||||
public static double between(Embedding embeddingA, Embedding embeddingB) {
|
||||
ensureNotNull(embeddingA, "embeddingA");
|
||||
ensureNotNull(embeddingB, "embeddingB");
|
||||
|
||||
float[] vectorA = embeddingA.vector();
|
||||
float[] vectorB = embeddingB.vector();
|
||||
|
||||
if (vectorA.length != vectorB.length) {
|
||||
throw illegalArgument("Length of vector a (%s) must be equal to the length of vector b (%s)",
|
||||
vectorA.length, vectorB.length);
|
||||
}
|
||||
|
||||
double dotProduct = 0.0;
|
||||
double normA = 0.0;
|
||||
double normB = 0.0;
|
||||
|
||||
for (int i = 0; i < vectorA.length; i++) {
|
||||
dotProduct += vectorA[i] * vectorB[i];
|
||||
normA += vectorA[i] * vectorA[i];
|
||||
normB += vectorB[i] * vectorB[i];
|
||||
}
|
||||
|
||||
return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts relevance score into cosine similarity.
|
||||
*
|
||||
* @param relevanceScore Relevance score in the range [0..1] where 0 is not relevant and 1 is relevant.
|
||||
* @return Cosine similarity in the range [-1..1] where -1 is not relevant and 1 is relevant.
|
||||
*/
|
||||
public static double fromRelevanceScore(double relevanceScore) {
|
||||
return relevanceScore * 2 - 1;
|
||||
}
|
||||
}
|
|
@ -3,14 +3,12 @@ package dev.langchain4j.store.embedding;
|
|||
public class RelevanceScore {
|
||||
|
||||
/**
|
||||
* Calculates the relevance score between two vectors using cosine similarity.
|
||||
* Converts cosine similarity into relevance score.
|
||||
*
|
||||
* @param a first vector
|
||||
* @param b second vector
|
||||
* @return score in the range [0, 1], where 0 indicates no relevance and 1 indicates full relevance
|
||||
* @param cosineSimilarity Cosine similarity in the range [-1..1] where -1 is not relevant and 1 is relevant.
|
||||
* @return Relevance score in the range [0..1] where 0 is not relevant and 1 is relevant.
|
||||
*/
|
||||
public static double cosine(float[] a, float[] b) {
|
||||
double cosineSimilarity = Similarity.cosine(a, b);
|
||||
return 1 - (1 - cosineSimilarity) / 2;
|
||||
public static double fromCosineSimilarity(double cosineSimilarity) {
|
||||
return (cosineSimilarity + 1) / 2;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,43 +0,0 @@
|
|||
package dev.langchain4j.store.embedding;
|
||||
|
||||
import static dev.langchain4j.internal.Exceptions.illegalArgument;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
|
||||
public class Similarity {
|
||||
|
||||
/**
|
||||
* Calculates cosine similarity between two vectors.
|
||||
* <p>
|
||||
* Cosine similarity measures the cosine of the angle between two vectors, indicating their directional similarity.
|
||||
* It produces a value in the range:
|
||||
* - -1 indicates vectors are diametrically opposed (opposite directions).
|
||||
* - 0 indicates vectors are orthogonal (no directional similarity).
|
||||
* - 1 indicates vectors are pointing in the same direction (but not necessarily of the same magnitude).
|
||||
* <p>
|
||||
* Not to be confused with cosine distance ([0..2]), which quantifies how different two vectors are.
|
||||
*
|
||||
* @param a first vector
|
||||
* @param b second vector
|
||||
* @return cosine similarity in the range [-1..1]
|
||||
*/
|
||||
public static double cosine(float[] a, float[] b) {
|
||||
ensureNotNull(a, "a");
|
||||
ensureNotNull(b, "b");
|
||||
if (a.length != b.length) {
|
||||
throw illegalArgument("Length of vector a (%s) must be equal to the length of vector b (%s)",
|
||||
a.length, b.length);
|
||||
}
|
||||
|
||||
double dotProduct = 0.0;
|
||||
double normA = 0.0;
|
||||
double normB = 0.0;
|
||||
|
||||
for (int i = 0; i < a.length; i++) {
|
||||
dotProduct += a[i] * b[i];
|
||||
normA += a[i] * a[i];
|
||||
normB += b[i] * b[i];
|
||||
}
|
||||
|
||||
return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
package dev.langchain4j.store.embedding;
|
||||
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.data.Percentage.withPercentage;
|
||||
|
||||
class CosineSimilarityTest {
|
||||
|
||||
@Test
|
||||
void should_calculate_cosine_similarity() {
|
||||
Embedding embeddingA = Embedding.from(new float[]{1, 1, 1});
|
||||
Embedding embeddingB = Embedding.from(new float[]{-1, -1, -1});
|
||||
|
||||
assertThat(CosineSimilarity.between(embeddingA, embeddingA)).isCloseTo(1, withPercentage(1));
|
||||
assertThat(CosineSimilarity.between(embeddingA, embeddingB)).isCloseTo(-1, withPercentage(1));
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_convert_relevance_score_into_cosine_similarity() {
|
||||
assertThat(CosineSimilarity.fromRelevanceScore(0)).isEqualTo(-1);
|
||||
assertThat(CosineSimilarity.fromRelevanceScore(0.5)).isEqualTo(0);
|
||||
assertThat(CosineSimilarity.fromRelevanceScore(1)).isEqualTo(1);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
package dev.langchain4j.store.embedding;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
class RelevanceScoreTest {
|
||||
|
||||
@Test
|
||||
void should_convert_cosine_similarity_into_relevance_score() {
|
||||
assertThat(RelevanceScore.fromCosineSimilarity(-1)).isEqualTo(0);
|
||||
assertThat(RelevanceScore.fromCosineSimilarity(0)).isEqualTo(0.5);
|
||||
assertThat(RelevanceScore.fromCosineSimilarity(1)).isEqualTo(1);
|
||||
}
|
||||
}
|
|
@ -4,6 +4,7 @@ import com.google.protobuf.Struct;
|
|||
import com.google.protobuf.Value;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.CosineSimilarity;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.RelevanceScore;
|
||||
|
@ -176,9 +177,10 @@ public class PineconeEmbeddingStoreImpl implements EmbeddingStore<TextSegment> {
|
|||
.get(METADATA_TEXT_SEGMENT);
|
||||
|
||||
Embedding embedding = Embedding.from(vector.getValuesList());
|
||||
double cosineSimilarity = CosineSimilarity.between(embedding, referenceEmbedding);
|
||||
|
||||
return new EmbeddingMatch<>(
|
||||
RelevanceScore.cosine(embedding.vector(), referenceEmbedding.vector()),
|
||||
RelevanceScore.fromCosineSimilarity(cosineSimilarity),
|
||||
vector.getId(),
|
||||
embedding,
|
||||
textSegmentValue == null ? null : TextSegment.from(textSegmentValue.getStringValue())
|
||||
|
|
|
@ -2,11 +2,18 @@ package dev.langchain4j.classification;
|
|||
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.CosineSimilarity;
|
||||
import dev.langchain4j.store.embedding.RelevanceScore;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static dev.langchain4j.internal.ValidationUtils.*;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureBetween;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureGreaterThanZero;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
import static java.util.Comparator.comparingDouble;
|
||||
import static java.util.stream.Collectors.toList;
|
||||
|
||||
|
@ -111,7 +118,8 @@ public class EmbeddingModelTextClassifier<E extends Enum<E>> implements TextClas
|
|||
double meanScore = 0;
|
||||
double maxScore = 0;
|
||||
for (Embedding exampleEmbedding : exampleEmbeddings) {
|
||||
double score = RelevanceScore.cosine(textEmbedding.vector(), exampleEmbedding.vector());
|
||||
double cosineSimilarity = CosineSimilarity.between(textEmbedding, exampleEmbedding);
|
||||
double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
|
||||
meanScore += score;
|
||||
maxScore = Math.max(score, maxScore);
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import com.google.gson.Gson;
|
|||
import com.google.gson.reflect.TypeToken;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.CosineSimilarity;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.RelevanceScore;
|
||||
|
@ -13,7 +14,12 @@ import java.lang.reflect.Type;
|
|||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.*;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.PriorityQueue;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.randomUUID;
|
||||
import static java.nio.file.StandardOpenOption.CREATE;
|
||||
|
@ -116,7 +122,8 @@ public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded
|
|||
PriorityQueue<EmbeddingMatch<Embedded>> matches = new PriorityQueue<>(comparator);
|
||||
|
||||
for (Entry<Embedded> entry : entries) {
|
||||
double score = RelevanceScore.cosine(entry.embedding.vector(), referenceEmbedding.vector());
|
||||
double cosineSimilarity = CosineSimilarity.between(entry.embedding, referenceEmbedding);
|
||||
double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
|
||||
if (score >= minScore) {
|
||||
matches.add(new EmbeddingMatch<>(score, entry.id, entry.embedding, entry.embedded));
|
||||
if (matches.size() > maxResults) {
|
||||
|
|
Loading…
Reference in New Issue