EmbeddingStoreWithRemovalIT: use awaitility

This commit is contained in:
LangChain4j 2024-08-20 12:16:54 +02:00
parent ae589e0c8d
commit 2e6d386f3f
4 changed files with 107 additions and 235 deletions

View File

@ -165,7 +165,7 @@ public abstract class EmbeddingStoreWithRemovalIT {
return searchResult.matches();
}
private static void awaitUntilAsserted(ThrowingRunnable assertion) {
protected static void awaitUntilAsserted(ThrowingRunnable assertion) {
Awaitility.await()
.pollInterval(Duration.ofMillis(500))
.atMost(Duration.ofSeconds(15))

View File

@ -123,6 +123,12 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>

View File

@ -0,0 +1,100 @@
package dev.langchain4j.store.embedding.elasticsearch;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreWithRemovalIT;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.io.IOException;
import static dev.langchain4j.internal.Utils.randomUUID;
import static org.assertj.core.api.Assertions.assertThat;
/**
 * Integration test covering the removal operations of {@link ElasticsearchEmbeddingStore}.
 * <p>
 * Extends {@link EmbeddingStoreWithRemovalIT} and adds two {@code removeAll()} checks that
 * inspect the Elasticsearch index directly. Each test runs against its own randomly named
 * index, which is dropped again once the test has finished.
 */
class ElasticsearchEmbeddingStoreRemovalIT extends EmbeddingStoreWithRemovalIT {

    static ElasticsearchClientHelper elasticsearchClientHelper = new ElasticsearchClientHelper();

    // Replaced before every test by createEmbeddingStore().
    // NOTE(review): the eager initializer is kept in case the store is requested
    // before createEmbeddingStore() has run - confirm against the base class.
    EmbeddingStore<TextSegment> embeddingStore = ElasticsearchEmbeddingStore.builder()
            .restClient(elasticsearchClientHelper.restClient)
            .indexName(randomUUID())
            .build();

    EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();

    // Name of the index used by the test that is currently running.
    String indexName;

    /** Starts the Elasticsearch services and verifies that both clients are available. */
    @BeforeAll
    static void startServices() throws IOException {
        elasticsearchClientHelper.startServices();

        assertThat(elasticsearchClientHelper.restClient).isNotNull();
        assertThat(elasticsearchClientHelper.client).isNotNull();
    }

    /** Stops the services started by {@link #startServices()}. */
    @AfterAll
    static void stopServices() throws IOException {
        elasticsearchClientHelper.stopServices();
    }

    /** Picks a fresh index name, removes any data stored under it and builds a store on top of it. */
    @BeforeEach
    void createEmbeddingStore() throws IOException {
        indexName = randomUUID();
        elasticsearchClientHelper.removeDataStore(indexName);

        embeddingStore = ElasticsearchEmbeddingStore.builder()
                .restClient(elasticsearchClientHelper.restClient)
                .indexName(indexName)
                .build();
    }

    /** Drops the index used by the test so nothing is left behind on a local test instance. */
    @AfterEach
    void removeDataStore() throws IOException {
        // When running against a local test instance the indices would otherwise
        // pile up, so clean up after ourselves.
        elasticsearchClientHelper.removeDataStore(indexName);
    }

    @Override
    protected EmbeddingStore<TextSegment> embeddingStore() {
        return embeddingStore;
    }

    @Override
    protected EmbeddingModel embeddingModel() {
        return embeddingModel;
    }

    @Test
    void should_remove_all() throws IOException {
        // given: two stored embeddings that are both visible
        EmbeddingStore<TextSegment> store = embeddingStore();
        EmbeddingModel model = embeddingModel();

        Embedding first = model.embed("test1").content();
        store.add(first);
        Embedding second = model.embed("test2").content();
        store.add(second);

        awaitUntilAsserted(() -> assertThat(getAllEmbeddings()).hasSize(2));

        // when
        store.removeAll();

        // then: the index itself is gone
        assertThat(indexExists()).isFalse();
    }

    @Test
    void should_not_fail_to_remove_non_existing_datastore() throws IOException {
        // when: nothing has been stored for this test's index
        embeddingStore.removeAll();

        // then
        assertThat(indexExists()).isFalse();
    }

    /**
     * Asks Elasticsearch whether the index of the current test exists.
     *
     * @return {@code true} if the index named {@link #indexName} exists
     * @throws IOException if the request to Elasticsearch fails
     */
    private boolean indexExists() throws IOException {
        return elasticsearchClientHelper.client.indices()
                .exists(existsRequest -> existsRequest.index(indexName))
                .value();
    }
}

View File

@ -1,234 +0,0 @@
package dev.langchain4j.store.embedding.elasticsearch;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreWithRemovalIT;
import org.junit.jupiter.api.*;
import org.testcontainers.shaded.org.awaitility.Awaitility;
import org.testcontainers.shaded.org.awaitility.core.ThrowingRunnable;
import java.io.IOException;
import java.time.Duration;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import static dev.langchain4j.internal.Utils.randomUUID;
import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metadataKey;
import static org.assertj.core.api.Assertions.assertThat;
/**
 * Integration test for the removal operations of {@link ElasticsearchEmbeddingStore}:
 * removing everything, by id, by a list of ids and by metadata filter.
 * <p>
 * Each test runs against its own randomly named index. Before asserting, the tests poll the
 * store until a search returns the expected number of matches.
 * <p>
 * TODO add some methods like "EmbeddingStoreWithRemovalIT#wait_for_ready()"
 * so we can remove the "specialized" implementations
 */
class ElasticsearchEmbeddingStoreRemoveIT extends EmbeddingStoreWithRemovalIT {

    static ElasticsearchClientHelper elasticsearchClientHelper = new ElasticsearchClientHelper();

    // Replaced before every test by createEmbeddingStore().
    EmbeddingStore<TextSegment> embeddingStore = ElasticsearchEmbeddingStore.builder()
            .restClient(elasticsearchClientHelper.restClient)
            .indexName(randomUUID())
            .build();

    EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();

    // Name of the index used by the test that is currently running.
    String indexName;

    /** Starts the Elasticsearch services and verifies that both clients are available. */
    @BeforeAll
    static void startServices() throws IOException {
        elasticsearchClientHelper.startServices();

        assertThat(elasticsearchClientHelper.restClient).isNotNull();
        assertThat(elasticsearchClientHelper.client).isNotNull();
    }

    /** Stops the services started by {@link #startServices()}. */
    @AfterAll
    static void stopServices() throws IOException {
        elasticsearchClientHelper.stopServices();
    }

    /** Picks a fresh index name, removes any data stored under it and builds a store on top of it. */
    @BeforeEach
    void createEmbeddingStore() throws IOException {
        indexName = randomUUID();
        elasticsearchClientHelper.removeDataStore(indexName);

        embeddingStore = ElasticsearchEmbeddingStore.builder()
                .restClient(elasticsearchClientHelper.restClient)
                .indexName(indexName)
                .build();
    }

    /** Drops the index used by the test so nothing is left behind on a local test instance. */
    @AfterEach
    void removeDataStore() throws IOException {
        // When running against a local test instance the indices would otherwise
        // pile up, so clean up after ourselves.
        elasticsearchClientHelper.removeDataStore(indexName);
    }

    @Override
    protected EmbeddingStore<TextSegment> embeddingStore() {
        return embeddingStore;
    }

    @Override
    protected EmbeddingModel embeddingModel() {
        return embeddingModel;
    }

    @Test
    void should_not_fail_to_remove_non_existing_datastore() throws IOException {
        // given: nothing has been stored

        // when
        embeddingStore.removeAll();

        // then
        assertThat(indexExists()).isFalse();
    }

    @Test
    void should_remove_all() throws IOException {
        // given
        Embedding first = embeddingModel.embed("hello").content();
        Embedding second = embeddingModel.embed("hello2").content();
        Embedding third = embeddingModel.embed("hello3").content();
        embeddingStore.add(first);
        embeddingStore.add(second);
        embeddingStore.add(third);

        EmbeddingSearchRequest searchRequest = searchRequestFor(first);
        awaitMatchCount(searchRequest, 3);

        // when
        embeddingStore.removeAll();

        // then: the index itself is gone
        assertThat(indexExists()).isFalse();
    }

    @Test
    void should_remove_by_id() {
        // given
        Embedding first = embeddingModel.embed("hello").content();
        Embedding second = embeddingModel.embed("hello2").content();
        Embedding third = embeddingModel.embed("hello3").content();
        String firstId = embeddingStore.add(first);
        String secondId = embeddingStore.add(second);
        String thirdId = embeddingStore.add(third);

        EmbeddingSearchRequest searchRequest = searchRequestFor(first);
        awaitMatchCount(searchRequest, 3);

        // when
        embeddingStore.remove(firstId);
        awaitMatchCount(searchRequest, 2);

        // then
        List<String> remainingIds = findMatchingIds(searchRequest);
        assertThat(remainingIds).containsExactly(secondId, thirdId);
    }

    @Test
    void should_remove_all_by_ids() {
        // given
        Embedding first = embeddingModel.embed("hello").content();
        Embedding second = embeddingModel.embed("hello2").content();
        Embedding third = embeddingModel.embed("hello3").content();
        String firstId = embeddingStore.add(first);
        String secondId = embeddingStore.add(second);
        String thirdId = embeddingStore.add(third);

        EmbeddingSearchRequest searchRequest = searchRequestFor(first);
        awaitMatchCount(searchRequest, 3);

        // when
        embeddingStore.removeAll(Arrays.asList(secondId, thirdId));
        awaitMatchCount(searchRequest, 1);

        // then
        List<String> remainingIds = findMatchingIds(searchRequest);
        assertThat(remainingIds).containsExactly(firstId);
    }

    @Test
    void should_remove_all_by_filter() {
        // given: one segment carrying the metadata the filter will target
        Metadata metadata = Metadata.metadata("id", "1");
        TextSegment matchingSegment = TextSegment.from("matching", metadata);
        Embedding matchingEmbedding = embeddingModel.embed(matchingSegment).content();
        embeddingStore.add(matchingEmbedding, matchingSegment);

        EmbeddingSearchRequest searchRequest = searchRequestFor(matchingEmbedding);
        awaitMatchCount(searchRequest, 1);

        // and: two embeddings without any metadata
        Embedding second = embeddingModel.embed("hello2").content();
        Embedding third = embeddingModel.embed("hello3").content();
        String secondId = embeddingStore.add(second);
        String thirdId = embeddingStore.add(third);
        awaitMatchCount(searchRequest, 3);

        // when
        embeddingStore.removeAll(metadataKey("id").isEqualTo("1"));
        awaitMatchCount(searchRequest, 2);

        // then
        List<String> remainingIds = findMatchingIds(searchRequest);
        assertThat(remainingIds).hasSize(2);
        assertThat(remainingIds).containsExactly(secondId, thirdId);
    }

    @Test
    void should_remove_all_by_filter_not_matching() {
        // given
        Embedding first = embeddingModel.embed("hello").content();
        Embedding second = embeddingModel.embed("hello2").content();
        Embedding third = embeddingModel.embed("hello3").content();
        embeddingStore.add(first);
        embeddingStore.add(second);
        embeddingStore.add(third);

        EmbeddingSearchRequest searchRequest = searchRequestFor(first);
        awaitMatchCount(searchRequest, 3);

        // when: the filter refers to a metadata key none of the entries has
        embeddingStore.removeAll(metadataKey("unknown").isEqualTo("1"));

        // then: nothing was removed
        List<String> remainingIds = findMatchingIds(searchRequest);
        assertThat(remainingIds).hasSize(3);
    }

    /**
     * Builds a search request for the given query embedding, limited to 10 results.
     *
     * @param queryEmbedding the embedding to search with
     * @return the search request
     */
    private static EmbeddingSearchRequest searchRequestFor(Embedding queryEmbedding) {
        return EmbeddingSearchRequest.builder()
                .queryEmbedding(queryEmbedding)
                .maxResults(10)
                .build();
    }

    /**
     * Runs the search once and collects the ids of the returned matches, in result order.
     *
     * @param searchRequest the request to execute against the store
     * @return the embedding ids of all matches
     */
    private List<String> findMatchingIds(EmbeddingSearchRequest searchRequest) {
        List<EmbeddingMatch<TextSegment>> matches = embeddingStore.search(searchRequest).matches();
        return matches.stream()
                .map(EmbeddingMatch::embeddingId)
                .collect(Collectors.toList());
    }

    /**
     * Polls the store until the search returns exactly {@code expectedCount} matches.
     *
     * @param searchRequest the request to execute on every poll
     * @param expectedCount the number of matches to wait for
     */
    private void awaitMatchCount(EmbeddingSearchRequest searchRequest, int expectedCount) {
        awaitAssertion(() -> assertThat(embeddingStore.search(searchRequest).matches()).hasSize(expectedCount));
    }

    /**
     * Asks Elasticsearch whether the index of the current test exists.
     *
     * @return {@code true} if the index named {@link #indexName} exists
     * @throws IOException if the request to Elasticsearch fails
     */
    private boolean indexExists() throws IOException {
        return elasticsearchClientHelper.client.indices()
                .exists(existsRequest -> existsRequest.index(indexName))
                .value();
    }

    /**
     * Retries the assertion once per second until it passes, giving up after five seconds.
     *
     * @param assertion the assertion to retry
     */
    private static void awaitAssertion(ThrowingRunnable assertion) {
        Awaitility.await()
                .pollInterval(Duration.ofSeconds(1))
                .atMost(Duration.ofSeconds(5))
                .untilAsserted(assertion);
    }
}