Redis: more tests
This commit is contained in:
parent
cd5f405b75
commit
f2b2f0214a
|
@ -46,14 +46,8 @@
|
|||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.mockito</groupId>
|
||||
<artifactId>mockito-core</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.mockito</groupId>
|
||||
<artifactId>mockito-junit-jupiter</artifactId>
|
||||
<groupId>org.assertj</groupId>
|
||||
<artifactId>assertj-core</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
|
@ -69,6 +63,21 @@
|
|||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
<licenses>
|
||||
<license>
|
||||
<name>Apache-2.0</name>
|
||||
<url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>
|
||||
<distribution>repo</distribution>
|
||||
<comments>A business-friendly OSS license</comments>
|
||||
</license>
|
||||
</licenses>
|
||||
|
||||
</project>
|
|
@ -219,7 +219,7 @@ public class RedisEmbeddingStore implements EmbeddingStore<TextSegment> {
|
|||
private String user;
|
||||
private String password;
|
||||
private Integer dimension;
|
||||
private List<String> metadataFieldsName;
|
||||
private List<String> metadataFieldsName = new ArrayList<>();
|
||||
|
||||
/**
|
||||
* @param host Redis Stack host
|
||||
|
|
|
@ -3,17 +3,24 @@ package dev.langchain4j.store.embedding.redis;
|
|||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.internal.Utils;
|
||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.CosineSimilarity;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.RelevanceScore;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import redis.clients.jedis.JedisPooled;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.randomUUID;
|
||||
import static java.util.Arrays.asList;
|
||||
import static java.util.Collections.emptyList;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.data.Percentage.withPercentage;
|
||||
|
||||
@Disabled("needs Redis running locally")
|
||||
class RedisEmbeddingStoreTest {
|
||||
|
@ -24,73 +31,234 @@ class RedisEmbeddingStoreTest {
|
|||
* docker run -d -p 6379:6379 -p 8001:8001 redis/redis-stack:latest
|
||||
*/
|
||||
|
||||
private final EmbeddingStore<TextSegment> store = new RedisEmbeddingStore(
|
||||
"localhost",
|
||||
6379,
|
||||
"default",
|
||||
"password",
|
||||
4,
|
||||
singletonList("field")
|
||||
);
|
||||
private static final String HOST = "localhost";
|
||||
private static final int PORT = 6379;
|
||||
private static final String METADATA_KEY = "test-key";
|
||||
|
||||
@Test
|
||||
void testAdd() {
|
||||
// test add without id
|
||||
String id = store.add(Embedding.from(asList(0.50f, 0.85f, 0.760f, 0.24f)),
|
||||
TextSegment.from("test string", Metadata.from("field", "value")));
|
||||
System.out.println("id=" + id);
|
||||
private EmbeddingStore<TextSegment> embeddingStore;
|
||||
|
||||
// test add with id
|
||||
String selfId = Utils.randomUUID();
|
||||
store.add(selfId, Embedding.from(asList(0.80f, 0.45f, 0.89f, 0.24f)));
|
||||
System.out.println("id=" + selfId);
|
||||
private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
|
||||
@BeforeEach
|
||||
void initEmptyRedisEmbeddingStore() {
|
||||
|
||||
flushDB();
|
||||
|
||||
embeddingStore = RedisEmbeddingStore.builder()
|
||||
.host(HOST)
|
||||
.port(PORT)
|
||||
.dimension(384)
|
||||
.build();
|
||||
}
|
||||
|
||||
private static void flushDB() {
|
||||
try (JedisPooled jedis = new JedisPooled(HOST, PORT)) {
|
||||
jedis.flushDB();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testAddAll() {
|
||||
// test add All Method without embedded
|
||||
List<String> ids = store.addAll(asList(
|
||||
Embedding.from(asList(0.3f, 0.87f, 0.90f, 0.24f)),
|
||||
Embedding.from(asList(0.54f, 0.34f, 0.67f, 0.24f)),
|
||||
Embedding.from(asList(0.80f, 0.45f, 0.779f, 0.5556f))
|
||||
));
|
||||
System.out.println("ids=" + ids);
|
||||
void should_add_embedding() {
|
||||
|
||||
// test add all method with embedded
|
||||
ids = store.addAll(asList(
|
||||
Embedding.from(asList(0.3f, 0.87f, 0.90f, 0.24f)),
|
||||
Embedding.from(asList(0.54f, 0.34f, 0.67f, 0.24f)),
|
||||
Embedding.from(asList(0.80f, 0.45f, 0.779f, 0.5556f))
|
||||
), asList(
|
||||
TextSegment.from("testString1", Metadata.from("field", "value1")),
|
||||
TextSegment.from("testString2", Metadata.from("field", "value2")),
|
||||
TextSegment.from("testingString3", Metadata.from("field", "value3"))
|
||||
));
|
||||
System.out.println("ids=" + ids);
|
||||
Embedding embedding = embeddingModel.embed(randomUUID()).content();
|
||||
|
||||
String id = embeddingStore.add(embedding);
|
||||
assertThat(id).isNotNull();
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
|
||||
assertThat(relevant).hasSize(1);
|
||||
|
||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(match.embeddingId()).isEqualTo(id);
|
||||
assertThat(match.embedding()).isEqualTo(embedding);
|
||||
assertThat(match.embedded()).isNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
void testAddEmpty() {
|
||||
// see log
|
||||
store.addAll(emptyList());
|
||||
void should_add_embedding_with_id() {
|
||||
|
||||
String id = randomUUID();
|
||||
Embedding embedding = embeddingModel.embed(randomUUID()).content();
|
||||
|
||||
embeddingStore.add(id, embedding);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
|
||||
assertThat(relevant).hasSize(1);
|
||||
|
||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(match.embeddingId()).isEqualTo(id);
|
||||
assertThat(match.embedding()).isEqualTo(embedding);
|
||||
assertThat(match.embedded()).isNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFindRelevant() {
|
||||
List<EmbeddingMatch<TextSegment>> res = store.findRelevant(Embedding.from(asList(0.80f, 0.45f, 0.89f, 0.24f)), 5);
|
||||
res.forEach(System.out::println);
|
||||
void should_add_embedding_with_segment() {
|
||||
|
||||
TextSegment segment = TextSegment.from(randomUUID());
|
||||
Embedding embedding = embeddingModel.embed(segment.text()).content();
|
||||
|
||||
String id = embeddingStore.add(embedding, segment);
|
||||
assertThat(id).isNotNull();
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
|
||||
assertThat(relevant).hasSize(1);
|
||||
|
||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(match.embeddingId()).isEqualTo(id);
|
||||
assertThat(match.embedding()).isEqualTo(embedding);
|
||||
assertThat(match.embedded()).isEqualTo(segment);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testScore() {
|
||||
String id = store.add(Embedding.from(asList(0.50f, 0.85f, 0.760f, 0.24f)),
|
||||
TextSegment.from("test string", Metadata.from("field", "value")));
|
||||
System.out.println("id=" + id);
|
||||
void should_add_embedding_with_segment_with_metadata() {
|
||||
|
||||
// use the same embedding to search
|
||||
List<EmbeddingMatch<TextSegment>> res = store.findRelevant(Embedding.from(asList(0.50f, 0.85f, 0.760f, 0.24f)), 1);
|
||||
res.forEach(System.out::println);
|
||||
flushDB();
|
||||
|
||||
// the result embeddingMatch score is 5.96046447754E-8, but expected is 1 because they are same vectors.
|
||||
embeddingStore = RedisEmbeddingStore.builder()
|
||||
.host(HOST)
|
||||
.port(PORT)
|
||||
.dimension(384)
|
||||
.metadataFieldsName(singletonList(METADATA_KEY))
|
||||
.build();
|
||||
|
||||
TextSegment segment = TextSegment.from(randomUUID(), Metadata.from(METADATA_KEY, "test-value"));
|
||||
Embedding embedding = embeddingModel.embed(segment.text()).content();
|
||||
|
||||
String id = embeddingStore.add(embedding, segment);
|
||||
assertThat(id).isNotNull();
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
|
||||
assertThat(relevant).hasSize(1);
|
||||
|
||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(match.embeddingId()).isEqualTo(id);
|
||||
assertThat(match.embedding()).isEqualTo(embedding);
|
||||
assertThat(match.embedded()).isEqualTo(segment);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_add_multiple_embeddings() {
|
||||
|
||||
Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content();
|
||||
Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content();
|
||||
|
||||
List<String> ids = embeddingStore.addAll(asList(firstEmbedding, secondEmbedding));
|
||||
assertThat(ids).hasSize(2);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
|
||||
assertThat(relevant).hasSize(2);
|
||||
|
||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
|
||||
assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
|
||||
assertThat(firstMatch.embedded()).isNull();
|
||||
|
||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
||||
assertThat(secondMatch.score()).isBetween(0d, 1d);
|
||||
assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
|
||||
assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
|
||||
assertThat(secondMatch.embedded()).isNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_add_multiple_embeddings_with_segments() {
|
||||
|
||||
TextSegment firstSegment = TextSegment.from(randomUUID());
|
||||
Embedding firstEmbedding = embeddingModel.embed(firstSegment.text()).content();
|
||||
TextSegment secondSegment = TextSegment.from(randomUUID());
|
||||
Embedding secondEmbedding = embeddingModel.embed(secondSegment.text()).content();
|
||||
|
||||
List<String> ids = embeddingStore.addAll(
|
||||
asList(firstEmbedding, secondEmbedding),
|
||||
asList(firstSegment, secondSegment)
|
||||
);
|
||||
assertThat(ids).hasSize(2);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
|
||||
assertThat(relevant).hasSize(2);
|
||||
|
||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
|
||||
assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
|
||||
assertThat(firstMatch.embedded()).isEqualTo(firstSegment);
|
||||
|
||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
||||
assertThat(secondMatch.score()).isBetween(0d, 1d);
|
||||
assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
|
||||
assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
|
||||
assertThat(secondMatch.embedded()).isEqualTo(secondSegment);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_find_with_min_score() {
|
||||
|
||||
String firstId = randomUUID();
|
||||
Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content();
|
||||
embeddingStore.add(firstId, firstEmbedding);
|
||||
|
||||
String secondId = randomUUID();
|
||||
Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content();
|
||||
embeddingStore.add(secondId, secondEmbedding);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
|
||||
assertThat(relevant).hasSize(2);
|
||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(firstMatch.embeddingId()).isEqualTo(firstId);
|
||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
||||
assertThat(secondMatch.score()).isBetween(0d, 1d);
|
||||
assertThat(secondMatch.embeddingId()).isEqualTo(secondId);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant2 = embeddingStore.findRelevant(
|
||||
firstEmbedding,
|
||||
10,
|
||||
secondMatch.score() - 0.01
|
||||
);
|
||||
assertThat(relevant2).hasSize(2);
|
||||
assertThat(relevant2.get(0).embeddingId()).isEqualTo(firstId);
|
||||
assertThat(relevant2.get(1).embeddingId()).isEqualTo(secondId);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant3 = embeddingStore.findRelevant(
|
||||
firstEmbedding,
|
||||
10,
|
||||
secondMatch.score()
|
||||
);
|
||||
assertThat(relevant3).hasSize(2);
|
||||
assertThat(relevant3.get(0).embeddingId()).isEqualTo(firstId);
|
||||
assertThat(relevant3.get(1).embeddingId()).isEqualTo(secondId);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant4 = embeddingStore.findRelevant(
|
||||
firstEmbedding,
|
||||
10,
|
||||
secondMatch.score() + 0.01
|
||||
);
|
||||
assertThat(relevant4).hasSize(1);
|
||||
assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_return_correct_score() {
|
||||
|
||||
Embedding embedding = embeddingModel.embed("hello").content();
|
||||
|
||||
String id = embeddingStore.add(embedding);
|
||||
assertThat(id).isNotNull();
|
||||
|
||||
Embedding referenceEmbedding = embeddingModel.embed("hi").content();
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(referenceEmbedding, 1);
|
||||
assertThat(relevant).hasSize(1);
|
||||
|
||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
||||
assertThat(match.score()).isCloseTo(
|
||||
RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)),
|
||||
withPercentage(1)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue