reducing duplication of *EmbeddingStoreIT
This commit is contained in:
parent
6aed16ab81
commit
bd802c352d
|
@ -21,9 +21,13 @@ public abstract class EmbeddingStoreWithoutMetadataIT {
|
|||
|
||||
@BeforeEach
|
||||
void beforeEach() {
|
||||
clearStore();
|
||||
ensureStoreIsEmpty();
|
||||
}
|
||||
|
||||
protected void clearStore() {
|
||||
}
|
||||
|
||||
protected void ensureStoreIsEmpty() {
|
||||
Embedding embedding = embeddingModel().embed("hello").content();
|
||||
assertThat(embeddingStore().findRelevant(embedding, 1000)).isEmpty();
|
||||
|
|
|
@ -39,6 +39,14 @@
|
|||
<artifactId>slf4j-api</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-core</artifactId>
|
||||
<classifier>tests</classifier>
|
||||
<type>test-jar</type>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-engine</artifactId>
|
||||
|
|
|
@ -16,8 +16,7 @@ import redis.clients.jedis.search.*;
|
|||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.isCollectionEmpty;
|
||||
import static dev.langchain4j.internal.Utils.randomUUID;
|
||||
import static dev.langchain4j.internal.Utils.*;
|
||||
import static dev.langchain4j.internal.ValidationUtils.*;
|
||||
import static dev.langchain4j.store.embedding.redis.RedisSchema.SCORE_FIELD_NAME;
|
||||
import static java.lang.String.format;
|
||||
|
@ -46,13 +45,15 @@ public class RedisEmbeddingStore implements EmbeddingStore<TextSegment> {
|
|||
* @param port Redis Stack Server port
|
||||
* @param user Redis Stack username (optional)
|
||||
* @param password Redis Stack password (optional)
|
||||
* @param dimension embedding vector dimension
|
||||
* @param metadataFieldsName metadata fields name (optional)
|
||||
* @param indexName The name of the index (optional). Default value: "embedding-index".
|
||||
* @param dimension Embedding vector dimension
|
||||
* @param metadataFieldsName Metadata fields name (optional)
|
||||
*/
|
||||
public RedisEmbeddingStore(String host,
|
||||
Integer port,
|
||||
String user,
|
||||
String password,
|
||||
String indexName,
|
||||
Integer dimension,
|
||||
List<String> metadataFieldsName) {
|
||||
ensureNotBlank(host, "host");
|
||||
|
@ -61,6 +62,7 @@ public class RedisEmbeddingStore implements EmbeddingStore<TextSegment> {
|
|||
|
||||
this.client = user == null ? new JedisPooled(host, port) : new JedisPooled(host, port, user, password);
|
||||
this.schema = RedisSchema.builder()
|
||||
.indexName(getOrDefault(indexName, "embedding-index"))
|
||||
.dimension(dimension)
|
||||
.metadataFieldsName(metadataFieldsName)
|
||||
.build();
|
||||
|
@ -140,8 +142,8 @@ public class RedisEmbeddingStore implements EmbeddingStore<TextSegment> {
|
|||
}
|
||||
|
||||
private boolean isIndexExist(String indexName) {
|
||||
Set<String> indexSets = client.ftList();
|
||||
return indexSets.contains(indexName);
|
||||
Set<String> indexes = client.ftList();
|
||||
return indexes.contains(indexName);
|
||||
}
|
||||
|
||||
private void addInternal(String id, Embedding embedding, TextSegment embedded) {
|
||||
|
@ -218,6 +220,7 @@ public class RedisEmbeddingStore implements EmbeddingStore<TextSegment> {
|
|||
private Integer port;
|
||||
private String user;
|
||||
private String password;
|
||||
private String indexName;
|
||||
private Integer dimension;
|
||||
private List<String> metadataFieldsName = new ArrayList<>();
|
||||
|
||||
|
@ -253,6 +256,15 @@ public class RedisEmbeddingStore implements EmbeddingStore<TextSegment> {
|
|||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param indexName The name of the index (optional). Default value: "embedding-index".
|
||||
* @return builder
|
||||
*/
|
||||
public Builder indexName(String indexName) {
|
||||
this.indexName = indexName;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param dimension embedding vector dimension
|
||||
* @return builder
|
||||
|
@ -271,7 +283,7 @@ public class RedisEmbeddingStore implements EmbeddingStore<TextSegment> {
|
|||
}
|
||||
|
||||
public RedisEmbeddingStore build() {
|
||||
return new RedisEmbeddingStore(host, port, user, password, dimension, metadataFieldsName);
|
||||
return new RedisEmbeddingStore(host, port, user, password, indexName, dimension, metadataFieldsName);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,8 +29,7 @@ class RedisSchema {
|
|||
|
||||
/* Redis schema field settings */
|
||||
|
||||
@Builder.Default
|
||||
private String indexName = "embedding-index";
|
||||
private String indexName;
|
||||
@Builder.Default
|
||||
private String prefix = "embedding:";
|
||||
@Builder.Default
|
||||
|
|
|
@ -1,276 +1,60 @@
|
|||
package dev.langchain4j.store.embedding.redis;
|
||||
|
||||
import com.redis.testcontainers.RedisStackContainer;
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import com.redis.testcontainers.RedisContainer;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.CosineSimilarity;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.RelevanceScore;
|
||||
import org.junit.jupiter.api.*;
|
||||
import org.junit.jupiter.api.TestInstance.Lifecycle;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreIT;
|
||||
import org.junit.jupiter.api.AfterAll;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import redis.clients.jedis.JedisPooled;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static com.redis.testcontainers.RedisStackContainer.DEFAULT_IMAGE_NAME;
|
||||
import static com.redis.testcontainers.RedisStackContainer.DEFAULT_TAG;
|
||||
import static dev.langchain4j.internal.Utils.randomUUID;
|
||||
import static java.util.Arrays.asList;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.data.Percentage.withPercentage;
|
||||
|
||||
@TestInstance(Lifecycle.PER_CLASS)
|
||||
class RedisEmbeddingStoreIT {
|
||||
class RedisEmbeddingStoreIT extends EmbeddingStoreIT {
|
||||
|
||||
/**
|
||||
* First start Redis locally:
|
||||
* docker pull redis/redis-stack:latest
|
||||
* docker run -d -p 6379:6379 -p 8001:8001 redis/redis-stack:latest
|
||||
*/
|
||||
static RedisContainer redis = new RedisContainer(DEFAULT_IMAGE_NAME.withTag(DEFAULT_TAG));
|
||||
|
||||
private static final String METADATA_KEY = "test-key";
|
||||
EmbeddingStore<TextSegment> embeddingStore;
|
||||
|
||||
private final RedisStackContainer redis = new RedisStackContainer(DEFAULT_IMAGE_NAME.withTag(DEFAULT_TAG));
|
||||
|
||||
private EmbeddingStore<TextSegment> embeddingStore;
|
||||
|
||||
private final EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
|
||||
@BeforeAll
|
||||
void setup() {
|
||||
static void beforeAll() {
|
||||
redis.start();
|
||||
}
|
||||
|
||||
@AfterAll
|
||||
void teardown() {
|
||||
redis.close();
|
||||
static void afterAll() {
|
||||
redis.stop();
|
||||
}
|
||||
|
||||
@BeforeEach
|
||||
void initEmptyRedisEmbeddingStore() {
|
||||
|
||||
flushDB();
|
||||
|
||||
embeddingStore = RedisEmbeddingStore.builder()
|
||||
.host(redis.getHost())
|
||||
.port(redis.getFirstMappedPort())
|
||||
.dimension(384)
|
||||
.build();
|
||||
}
|
||||
|
||||
private void flushDB() {
|
||||
@Override
|
||||
protected void clearStore() {
|
||||
try (JedisPooled jedis = new JedisPooled(redis.getHost(), redis.getFirstMappedPort())) {
|
||||
jedis.flushDB();
|
||||
jedis.flushDB(); // TODO fix: why redis returns embeddings from different indexes?
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_add_embedding() {
|
||||
|
||||
Embedding embedding = embeddingModel.embed(randomUUID()).content();
|
||||
|
||||
String id = embeddingStore.add(embedding);
|
||||
assertThat(id).isNotNull();
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
|
||||
assertThat(relevant).hasSize(1);
|
||||
|
||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(match.embeddingId()).isEqualTo(id);
|
||||
assertThat(match.embedding()).isEqualTo(embedding);
|
||||
assertThat(match.embedded()).isNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_add_embedding_with_id() {
|
||||
|
||||
String id = randomUUID();
|
||||
Embedding embedding = embeddingModel.embed(randomUUID()).content();
|
||||
|
||||
embeddingStore.add(id, embedding);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
|
||||
assertThat(relevant).hasSize(1);
|
||||
|
||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(match.embeddingId()).isEqualTo(id);
|
||||
assertThat(match.embedding()).isEqualTo(embedding);
|
||||
assertThat(match.embedded()).isNull();
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_add_embedding_with_segment() {
|
||||
|
||||
TextSegment segment = TextSegment.from(randomUUID());
|
||||
Embedding embedding = embeddingModel.embed(segment.text()).content();
|
||||
|
||||
String id = embeddingStore.add(embedding, segment);
|
||||
assertThat(id).isNotNull();
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
|
||||
assertThat(relevant).hasSize(1);
|
||||
|
||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(match.embeddingId()).isEqualTo(id);
|
||||
assertThat(match.embedding()).isEqualTo(embedding);
|
||||
assertThat(match.embedded()).isEqualTo(segment);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_add_embedding_with_segment_with_metadata() {
|
||||
|
||||
flushDB();
|
||||
|
||||
embeddingStore = RedisEmbeddingStore.builder()
|
||||
.host(redis.getHost())
|
||||
.port(redis.getFirstMappedPort())
|
||||
.indexName(randomUUID())
|
||||
.dimension(384)
|
||||
.metadataFieldsName(singletonList(METADATA_KEY))
|
||||
.metadataFieldsName(singletonList("test-key"))
|
||||
.build();
|
||||
|
||||
TextSegment segment = TextSegment.from(randomUUID(), Metadata.from(METADATA_KEY, "test-value"));
|
||||
Embedding embedding = embeddingModel.embed(segment.text()).content();
|
||||
|
||||
String id = embeddingStore.add(embedding, segment);
|
||||
assertThat(id).isNotNull();
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(embedding, 10);
|
||||
assertThat(relevant).hasSize(1);
|
||||
|
||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
||||
assertThat(match.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(match.embeddingId()).isEqualTo(id);
|
||||
assertThat(match.embedding()).isEqualTo(embedding);
|
||||
assertThat(match.embedded()).isEqualTo(segment);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_add_multiple_embeddings() {
|
||||
|
||||
Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content();
|
||||
Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content();
|
||||
|
||||
List<String> ids = embeddingStore.addAll(asList(firstEmbedding, secondEmbedding));
|
||||
assertThat(ids).hasSize(2);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
|
||||
assertThat(relevant).hasSize(2);
|
||||
|
||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
|
||||
assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
|
||||
assertThat(firstMatch.embedded()).isNull();
|
||||
|
||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
||||
assertThat(secondMatch.score()).isBetween(0d, 1d);
|
||||
assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
|
||||
assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
|
||||
assertThat(secondMatch.embedded()).isNull();
|
||||
@Override
|
||||
protected EmbeddingStore<TextSegment> embeddingStore() {
|
||||
return embeddingStore;
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_add_multiple_embeddings_with_segments() {
|
||||
|
||||
TextSegment firstSegment = TextSegment.from(randomUUID());
|
||||
Embedding firstEmbedding = embeddingModel.embed(firstSegment.text()).content();
|
||||
TextSegment secondSegment = TextSegment.from(randomUUID());
|
||||
Embedding secondEmbedding = embeddingModel.embed(secondSegment.text()).content();
|
||||
|
||||
List<String> ids = embeddingStore.addAll(
|
||||
asList(firstEmbedding, secondEmbedding),
|
||||
asList(firstSegment, secondSegment)
|
||||
);
|
||||
assertThat(ids).hasSize(2);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
|
||||
assertThat(relevant).hasSize(2);
|
||||
|
||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(firstMatch.embeddingId()).isEqualTo(ids.get(0));
|
||||
assertThat(firstMatch.embedding()).isEqualTo(firstEmbedding);
|
||||
assertThat(firstMatch.embedded()).isEqualTo(firstSegment);
|
||||
|
||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
||||
assertThat(secondMatch.score()).isBetween(0d, 1d);
|
||||
assertThat(secondMatch.embeddingId()).isEqualTo(ids.get(1));
|
||||
assertThat(secondMatch.embedding()).isEqualTo(secondEmbedding);
|
||||
assertThat(secondMatch.embedded()).isEqualTo(secondSegment);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_find_with_min_score() {
|
||||
|
||||
String firstId = randomUUID();
|
||||
Embedding firstEmbedding = embeddingModel.embed(randomUUID()).content();
|
||||
embeddingStore.add(firstId, firstEmbedding);
|
||||
|
||||
String secondId = randomUUID();
|
||||
Embedding secondEmbedding = embeddingModel.embed(randomUUID()).content();
|
||||
embeddingStore.add(secondId, secondEmbedding);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(firstEmbedding, 10);
|
||||
assertThat(relevant).hasSize(2);
|
||||
EmbeddingMatch<TextSegment> firstMatch = relevant.get(0);
|
||||
assertThat(firstMatch.score()).isCloseTo(1, withPercentage(1));
|
||||
assertThat(firstMatch.embeddingId()).isEqualTo(firstId);
|
||||
EmbeddingMatch<TextSegment> secondMatch = relevant.get(1);
|
||||
assertThat(secondMatch.score()).isBetween(0d, 1d);
|
||||
assertThat(secondMatch.embeddingId()).isEqualTo(secondId);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant2 = embeddingStore.findRelevant(
|
||||
firstEmbedding,
|
||||
10,
|
||||
secondMatch.score() - 0.01
|
||||
);
|
||||
assertThat(relevant2).hasSize(2);
|
||||
assertThat(relevant2.get(0).embeddingId()).isEqualTo(firstId);
|
||||
assertThat(relevant2.get(1).embeddingId()).isEqualTo(secondId);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant3 = embeddingStore.findRelevant(
|
||||
firstEmbedding,
|
||||
10,
|
||||
secondMatch.score()
|
||||
);
|
||||
assertThat(relevant3).hasSize(2);
|
||||
assertThat(relevant3.get(0).embeddingId()).isEqualTo(firstId);
|
||||
assertThat(relevant3.get(1).embeddingId()).isEqualTo(secondId);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant4 = embeddingStore.findRelevant(
|
||||
firstEmbedding,
|
||||
10,
|
||||
secondMatch.score() + 0.01
|
||||
);
|
||||
assertThat(relevant4).hasSize(1);
|
||||
assertThat(relevant4.get(0).embeddingId()).isEqualTo(firstId);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_return_correct_score() {
|
||||
|
||||
Embedding embedding = embeddingModel.embed("hello").content();
|
||||
|
||||
String id = embeddingStore.add(embedding);
|
||||
assertThat(id).isNotNull();
|
||||
|
||||
Embedding referenceEmbedding = embeddingModel.embed("hi").content();
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> relevant = embeddingStore.findRelevant(referenceEmbedding, 1);
|
||||
assertThat(relevant).hasSize(1);
|
||||
|
||||
EmbeddingMatch<TextSegment> match = relevant.get(0);
|
||||
assertThat(match.score()).isCloseTo(
|
||||
RelevanceScore.fromCosineSimilarity(CosineSimilarity.between(embedding, referenceEmbedding)),
|
||||
withPercentage(1)
|
||||
);
|
||||
@Override
|
||||
protected EmbeddingModel embeddingModel() {
|
||||
return embeddingModel;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue