DefaultRetrievalAugmentor: process in the same thread when single query and single retriever, otherwise use executor and cache threads for 1 second (#1479)
## Issue This PR partially fixes #1454 ## Context `DefaultRetrievalAugmentor` currently uses an `Executor` to parallelize the processing (consider multiple `Query`s and/or multiple `ContentRetriever`s). The default `Executor` instance caches (non-daemon) threads for 60 seconds, so when the application is ready to shut down, it can hang for another 60 seconds before it can actually exit. For the majority of the use cases (single `Query` and single `ContentRetriever`) there is no need to use an `Executor`, processing can be done in the same thread without an `Executor`, thus there will be no hanging. For the rest of the use cases we can use the `Executor` to parallelize the processing as before. But for default `Executor`, reduce the time from 60 to 1 second, which makes "handing time" acceptable. In any case, the user can always provide a custom instance of an `Executor` and manage it externally. ## Change - Changes in `DefaultRetrievalAugmentor`: - When there is only a single `Query` and a single `ContentRetriever` (majority of use cases), processing is done in the same thread (`Executor` is not used at all) - Otherwise, the `Executor`is used to parallelize query routing and content retrieval. The default `Executor` now caches threads for 1 second (instead of 60 seconds) - Added javadoc and documentation - Added documentation for https://github.com/langchain4j/langchain4j/issues/1454 ## General checklist - [X] There are no breaking changes - [X] I have added unit and integration tests for my change - [x] I have manually run all the unit and integration tests in the module I have added/changed, and they are all green - [x] I have manually run all the unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules, and they are all green - [X] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) - [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features) - [ ] I have added/updated [Spring Boot starter(s)](https://github.com/langchain4j/langchain4j-spring) (if applicable)
This commit is contained in:
parent
c1c696bff6
commit
4f09db98b8
|
@ -8,9 +8,6 @@ LangChain4j provides a few popular local embedding models packaged as maven depe
|
|||
They are powered by [ONNX runtime](https://onnxruntime.ai/docs/get-started/with-java.html)
|
||||
and are running in the same java process.
|
||||
|
||||
By default, embedding is parallelized using all available CPU cores.
|
||||
Embedding using GPU is not supported yet.
|
||||
|
||||
Each model is provided in 2 flavours: original and quantized (has a `-q` suffix in maven artifact name and `Quantized` in the class name).
|
||||
|
||||
For example:
|
||||
|
@ -44,6 +41,24 @@ Embedding embedding = response.content();
|
|||
The complete list of all embedding models can be found [here](https://github.com/langchain4j/langchain4j-embeddings).
|
||||
|
||||
|
||||
## Parallelization
|
||||
|
||||
By default, the embedding process is parallelized using all available CPU cores,
|
||||
so each `TextSegment` is embedded in a separate thread.
|
||||
|
||||
The parallelization is done by using an `Executor`.
|
||||
By default, in-process embedding models use a cached thread pool
|
||||
with the number of threads equal to the number of available processors.
|
||||
Threads are cached for 1 second.
|
||||
|
||||
You can provide a custom instance of the `Executor` when creating a model:
|
||||
```java
|
||||
Executor = ...;
|
||||
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(executor);
|
||||
```
|
||||
|
||||
Embedding using GPU is not supported yet.
|
||||
|
||||
## Custom models
|
||||
|
||||
Many models (e.g., from [Hugging Face](https://huggingface.co/)) can be used,
|
||||
|
|
|
@ -474,6 +474,7 @@ RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
|
|||
.queryRouter(...)
|
||||
.contentAggregator(...)
|
||||
.contentInjector(...)
|
||||
.executor(...)
|
||||
.build();
|
||||
|
||||
Assistant assistant = AiServices.builder(Assistant.class)
|
||||
|
|
|
@ -743,6 +743,20 @@ More details are coming soon.
|
|||
|
||||
More details are coming soon.
|
||||
|
||||
### Parallelization
|
||||
|
||||
When there is only a single `Query` and a single `ContentRetriever`,
|
||||
`DefaultRetrievalAugmentor` performs query routing and content retrieval in the same thread.
|
||||
Otherwise, an `Executor` is used to parallelize the processing.
|
||||
By default, a modified (`keepAliveTime` is 1 second instead of 60 seconds) `Executors.newCachedThreadPool()`
|
||||
is used, but you can provide a custom `Executor` instance when creating the `DefaultRetrievalAugmentor`:
|
||||
```java
|
||||
DefaultRetrievalAugmentor.builder()
|
||||
...
|
||||
.executor(executor)
|
||||
.build;
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
- [Easy RAG](https://github.com/langchain4j/langchain4j-examples/blob/main/rag-examples/src/main/java/_1_easy/Easy_RAG_Example.java)
|
||||
|
|
|
@ -21,16 +21,14 @@ import org.slf4j.LoggerFactory;
|
|||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.Executor;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.*;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.getOrDefault;
|
||||
import static dev.langchain4j.internal.Utils.isNotNullOrBlank;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
|
||||
import static java.util.Collections.*;
|
||||
import static java.util.concurrent.CompletableFuture.allOf;
|
||||
import static java.util.concurrent.CompletableFuture.supplyAsync;
|
||||
import static java.util.concurrent.TimeUnit.SECONDS;
|
||||
import static java.util.stream.Collectors.*;
|
||||
|
||||
/**
|
||||
|
@ -90,8 +88,11 @@ import static java.util.stream.Collectors.*;
|
|||
* Nonetheless, you are encouraged to use one of the advanced ready-to-use implementations or create a custom one.
|
||||
* <br>
|
||||
* <br>
|
||||
* By default, query routing and content retrieval are performed concurrently (for efficiency)
|
||||
* using {@link Executors#newCachedThreadPool()}, but you can provide a custom {@link Executor}.
|
||||
* When there is only a single {@link Query} and a single {@link ContentRetriever},
|
||||
* query routing and content retrieval are performed in the same thread.
|
||||
* Otherwise, an {@link Executor} is used to parallelize the processing.
|
||||
* By default, a modified (keepAliveTime is 1 second instead of 60 seconds) {@link Executors#newCachedThreadPool()}
|
||||
* is used, but you can provide a custom {@link Executor} instance.
|
||||
*
|
||||
* @see DefaultQueryTransformer
|
||||
* @see DefaultQueryRouter
|
||||
|
@ -118,7 +119,15 @@ public class DefaultRetrievalAugmentor implements RetrievalAugmentor {
|
|||
this.queryRouter = ensureNotNull(queryRouter, "queryRouter");
|
||||
this.contentAggregator = getOrDefault(contentAggregator, DefaultContentAggregator::new);
|
||||
this.contentInjector = getOrDefault(contentInjector, DefaultContentInjector::new);
|
||||
this.executor = getOrDefault(executor, Executors::newCachedThreadPool);
|
||||
this.executor = getOrDefault(executor, DefaultRetrievalAugmentor::createDefaultExecutor);
|
||||
}
|
||||
|
||||
private static ExecutorService createDefaultExecutor() {
|
||||
return new ThreadPoolExecutor(
|
||||
0, Integer.MAX_VALUE,
|
||||
1, SECONDS,
|
||||
new SynchronousQueue<>()
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -142,20 +151,7 @@ public class DefaultRetrievalAugmentor implements RetrievalAugmentor {
|
|||
Collection<Query> queries = queryTransformer.transform(originalQuery);
|
||||
logQueries(originalQuery, queries);
|
||||
|
||||
Map<Query, CompletableFuture<Collection<List<Content>>>> queryToFutureContents = new ConcurrentHashMap<>();
|
||||
queries.forEach(query -> {
|
||||
CompletableFuture<Collection<List<Content>>> futureContents =
|
||||
supplyAsync(() -> {
|
||||
Collection<ContentRetriever> retrievers = queryRouter.route(query);
|
||||
log(query, retrievers);
|
||||
return retrievers;
|
||||
},
|
||||
executor
|
||||
).thenCompose(retrievers -> retrieveFromAll(retrievers, query));
|
||||
queryToFutureContents.put(query, futureContents);
|
||||
});
|
||||
|
||||
Map<Query, Collection<List<Content>>> queryToContents = join(queryToFutureContents);
|
||||
Map<Query, Collection<List<Content>>> queryToContents = process(queries);
|
||||
|
||||
List<Content> contents = contentAggregator.aggregate(queryToContents);
|
||||
log(queryToContents, contents);
|
||||
|
@ -169,6 +165,39 @@ public class DefaultRetrievalAugmentor implements RetrievalAugmentor {
|
|||
.build();
|
||||
}
|
||||
|
||||
private Map<Query, Collection<List<Content>>> process(Collection<Query> queries) {
|
||||
if (queries.size() == 1) {
|
||||
Query query = queries.iterator().next();
|
||||
Collection<ContentRetriever> retrievers = queryRouter.route(query);
|
||||
if (retrievers.size() == 1) {
|
||||
ContentRetriever contentRetriever = retrievers.iterator().next();
|
||||
List<Content> contents = contentRetriever.retrieve(query);
|
||||
return singletonMap(query, singletonList(contents));
|
||||
} else if (retrievers.size() > 1) {
|
||||
Collection<List<Content>> contents = retrieveFromAll(retrievers, query).join();
|
||||
return singletonMap(query, contents);
|
||||
} else {
|
||||
return emptyMap();
|
||||
}
|
||||
} else if (queries.size() > 1) {
|
||||
Map<Query, CompletableFuture<Collection<List<Content>>>> queryToFutureContents = new ConcurrentHashMap<>();
|
||||
queries.forEach(query -> {
|
||||
CompletableFuture<Collection<List<Content>>> futureContents =
|
||||
supplyAsync(() -> {
|
||||
Collection<ContentRetriever> retrievers = queryRouter.route(query);
|
||||
log(query, retrievers);
|
||||
return retrievers;
|
||||
},
|
||||
executor
|
||||
).thenCompose(retrievers -> retrieveFromAll(retrievers, query));
|
||||
queryToFutureContents.put(query, futureContents);
|
||||
});
|
||||
return join(queryToFutureContents);
|
||||
} else {
|
||||
return emptyMap();
|
||||
}
|
||||
}
|
||||
|
||||
private CompletableFuture<Collection<List<Content>>> retrieveFromAll(Collection<ContentRetriever> retrievers,
|
||||
Query query) {
|
||||
List<CompletableFuture<List<Content>>> futureContents = retrievers.stream()
|
||||
|
|
|
@ -10,9 +10,10 @@ import dev.langchain4j.rag.query.Metadata;
|
|||
import dev.langchain4j.rag.query.Query;
|
||||
import dev.langchain4j.rag.query.router.DefaultQueryRouter;
|
||||
import dev.langchain4j.rag.query.router.QueryRouter;
|
||||
import dev.langchain4j.rag.query.transformer.DefaultQueryTransformer;
|
||||
import dev.langchain4j.rag.query.transformer.QueryTransformer;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
||||
import java.util.Collection;
|
||||
|
@ -25,6 +26,7 @@ import java.util.stream.Stream;
|
|||
|
||||
import static java.util.Arrays.asList;
|
||||
import static java.util.Collections.emptyList;
|
||||
import static java.util.Collections.singletonList;
|
||||
import static java.util.stream.Collectors.joining;
|
||||
import static java.util.stream.Collectors.toList;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
@ -34,7 +36,7 @@ class DefaultRetrievalAugmentorTest {
|
|||
|
||||
@ParameterizedTest
|
||||
@MethodSource("executors")
|
||||
void should_augment_user_message(Executor executor) {
|
||||
void should_augment_user_message__multiple_queries_multiple_retrievers(Executor executor) {
|
||||
|
||||
// given
|
||||
Query query1 = Query.from("query 1");
|
||||
|
@ -123,6 +125,152 @@ class DefaultRetrievalAugmentorTest {
|
|||
verifyNoMoreInteractions(contentInjector);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_augment_user_message__single_query_multiple_retrievers() {
|
||||
|
||||
// given
|
||||
QueryTransformer queryTransformer = spy(new DefaultQueryTransformer());
|
||||
|
||||
Content content1 = Content.from("content 1");
|
||||
Content content2 = Content.from("content 2");
|
||||
ContentRetriever contentRetriever1 = spy(new TestContentRetriever(content1, content2));
|
||||
|
||||
Content content3 = Content.from("content 3");
|
||||
Content content4 = Content.from("content 4");
|
||||
ContentRetriever contentRetriever2 = spy(new TestContentRetriever(content3, content4));
|
||||
|
||||
QueryRouter queryRouter = spy(new DefaultQueryRouter(contentRetriever1, contentRetriever2));
|
||||
|
||||
ContentAggregator contentAggregator = spy(new TestContentAggregator());
|
||||
|
||||
ContentInjector contentInjector = spy(new TestContentInjector());
|
||||
|
||||
Executor executor = spy(new TestExecutor());
|
||||
|
||||
RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
|
||||
.queryTransformer(queryTransformer)
|
||||
.queryRouter(queryRouter)
|
||||
.contentAggregator(contentAggregator)
|
||||
.contentInjector(contentInjector)
|
||||
.executor(executor)
|
||||
.build();
|
||||
|
||||
UserMessage userMessage = UserMessage.from("query");
|
||||
|
||||
Metadata metadata = Metadata.from(userMessage, null, null);
|
||||
|
||||
// when
|
||||
UserMessage augmented = retrievalAugmentor.augment(userMessage, metadata);
|
||||
|
||||
// then
|
||||
assertThat(augmented.singleText()).isEqualTo(
|
||||
"query\n" +
|
||||
"content 1\n" +
|
||||
"content 2\n" +
|
||||
"content 3\n" +
|
||||
"content 4"
|
||||
);
|
||||
|
||||
Query query = Query.from("query", metadata);
|
||||
verify(queryTransformer).transform(query);
|
||||
verifyNoMoreInteractions(queryTransformer);
|
||||
|
||||
verify(queryRouter).route(query);
|
||||
verifyNoMoreInteractions(queryRouter);
|
||||
|
||||
verify(contentRetriever1).retrieve(query);
|
||||
verifyNoMoreInteractions(contentRetriever1);
|
||||
|
||||
verify(contentRetriever2).retrieve(query);
|
||||
verifyNoMoreInteractions(contentRetriever2);
|
||||
|
||||
Map<Query, Collection<List<Content>>> queryToContents = new HashMap<>();
|
||||
queryToContents.put(query, asList(
|
||||
asList(content1, content2),
|
||||
asList(content3, content4)
|
||||
|
||||
));
|
||||
verify(contentAggregator).aggregate(queryToContents);
|
||||
verifyNoMoreInteractions(contentAggregator);
|
||||
|
||||
verify(contentInjector).inject(asList(content1, content2, content3, content4), userMessage);
|
||||
verify(contentInjector).inject(asList(content1, content2, content3, content4), (ChatMessage) userMessage);
|
||||
verifyNoMoreInteractions(contentInjector);
|
||||
|
||||
verify(executor, times(2)).execute(any());
|
||||
verifyNoMoreInteractions(executor);
|
||||
}
|
||||
|
||||
private static class TestExecutor implements Executor {
|
||||
|
||||
@Override
|
||||
public void execute(Runnable command) {
|
||||
command.run();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_augment_user_message__single_query_single_retriever() {
|
||||
|
||||
// given
|
||||
QueryTransformer queryTransformer = spy(new DefaultQueryTransformer());
|
||||
|
||||
Content content1 = Content.from("content 1");
|
||||
Content content2 = Content.from("content 2");
|
||||
ContentRetriever contentRetriever = spy(new TestContentRetriever(content1, content2));
|
||||
|
||||
QueryRouter queryRouter = spy(new DefaultQueryRouter(contentRetriever));
|
||||
|
||||
ContentAggregator contentAggregator = spy(new TestContentAggregator());
|
||||
|
||||
ContentInjector contentInjector = spy(new TestContentInjector());
|
||||
|
||||
Executor executor = mock(Executor.class);
|
||||
|
||||
RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
|
||||
.queryTransformer(queryTransformer)
|
||||
.queryRouter(queryRouter)
|
||||
.contentAggregator(contentAggregator)
|
||||
.contentInjector(contentInjector)
|
||||
.executor(executor)
|
||||
.build();
|
||||
|
||||
UserMessage userMessage = UserMessage.from("query");
|
||||
|
||||
Metadata metadata = Metadata.from(userMessage, null, null);
|
||||
|
||||
// when
|
||||
UserMessage augmented = retrievalAugmentor.augment(userMessage, metadata);
|
||||
|
||||
// then
|
||||
assertThat(augmented.singleText()).isEqualTo(
|
||||
"query\n" +
|
||||
"content 1\n" +
|
||||
"content 2"
|
||||
);
|
||||
|
||||
Query query = Query.from("query", metadata);
|
||||
verify(queryTransformer).transform(query);
|
||||
verifyNoMoreInteractions(queryTransformer);
|
||||
|
||||
verify(queryRouter).route(query);
|
||||
verifyNoMoreInteractions(queryRouter);
|
||||
|
||||
verify(contentRetriever).retrieve(query);
|
||||
verifyNoMoreInteractions(contentRetriever);
|
||||
|
||||
Map<Query, Collection<List<Content>>> queryToContents = new HashMap<>();
|
||||
queryToContents.put(query, singletonList(asList(content1, content2)));
|
||||
verify(contentAggregator).aggregate(queryToContents);
|
||||
verifyNoMoreInteractions(contentAggregator);
|
||||
|
||||
verify(contentInjector).inject(asList(content1, content2), userMessage);
|
||||
verify(contentInjector).inject(asList(content1, content2), (ChatMessage) userMessage);
|
||||
verifyNoMoreInteractions(contentInjector);
|
||||
|
||||
verifyNoInteractions(executor);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@MethodSource("executors")
|
||||
void should_not_augment_when_router_does_not_return_retrievers(Executor executor) {
|
||||
|
@ -150,13 +298,15 @@ class DefaultRetrievalAugmentorTest {
|
|||
verifyNoMoreInteractions(queryRouter);
|
||||
}
|
||||
|
||||
static Stream<Arguments> executors() {
|
||||
return Stream.<Arguments>builder()
|
||||
.add(Arguments.of(Executors.newCachedThreadPool()))
|
||||
.add(Arguments.of(Executors.newFixedThreadPool(1)))
|
||||
.add(Arguments.of(Executors.newFixedThreadPool(2)))
|
||||
.add(Arguments.of(Executors.newFixedThreadPool(3)))
|
||||
.add(Arguments.of(Executors.newFixedThreadPool(4)))
|
||||
static Stream<Executor> executors() {
|
||||
return Stream.<Executor>builder()
|
||||
.add(Executors.newCachedThreadPool())
|
||||
.add(Executors.newFixedThreadPool(1))
|
||||
.add(Executors.newFixedThreadPool(2))
|
||||
.add(Executors.newFixedThreadPool(3))
|
||||
.add(Executors.newFixedThreadPool(4))
|
||||
.add(Runnable::run) // same thread executor
|
||||
.add(null) // to use default Executor in DefaultRetrievalAugmentor
|
||||
.build();
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue