DefaultRetrievalAugmentor: process in the same thread when single query and single retriever, otherwise use executor and cache threads for 1 second (#1479)

## Issue
This PR partially fixes #1454

## Context
`DefaultRetrievalAugmentor` currently uses an `Executor` to parallelize
processing when there are multiple `Query`s and/or multiple
`ContentRetriever`s.
The default `Executor` instance caches (non-daemon) threads for 60
seconds, so when the application is ready to shut down, it can hang for
up to another 60 seconds before it can actually exit.
For the majority of use cases (a single `Query` and a single
`ContentRetriever`), there is no need for an `Executor`: processing can
be done in the same thread, so there is no hanging.
For the remaining use cases, the `Executor` is still used to parallelize
the processing as before, but the keep-alive time of the default
`Executor` is reduced from 60 seconds to 1 second, which makes the hang
time acceptable.
In any case, the user can always provide a custom `Executor` instance
and manage it externally.


## Change
- Changes in `DefaultRetrievalAugmentor`:
  - When there is only a single `Query` and a single `ContentRetriever`
    (the majority of use cases), processing is done in the same thread
    (the `Executor` is not used at all)
  - Otherwise, the `Executor` is used to parallelize query routing and
    content retrieval. The default `Executor` now caches threads for 1
    second (instead of 60 seconds)
  - Added javadoc and documentation
- Added documentation for
https://github.com/langchain4j/langchain4j/issues/1454



## General checklist
- [X] There are no breaking changes
- [X] I have added unit and integration tests for my change
- [x] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [x] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
- [X] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)
LangChain4j 2024-07-23 11:49:53 +02:00 committed by GitHub
parent c1c696bff6
commit 4f09db98b8
5 changed files with 243 additions and 34 deletions


@ -8,9 +8,6 @@ LangChain4j provides a few popular local embedding models packaged as maven depe
They are powered by [ONNX runtime](https://onnxruntime.ai/docs/get-started/with-java.html)
and are running in the same java process.
By default, embedding is parallelized using all available CPU cores.
Embedding using GPU is not supported yet.
Each model is provided in 2 flavours: original and quantized (has a `-q` suffix in maven artifact name and `Quantized` in the class name).
For example:
@ -44,6 +41,24 @@ Embedding embedding = response.content();
The complete list of all embedding models can be found [here](https://github.com/langchain4j/langchain4j-embeddings).
## Parallelization
By default, the embedding process is parallelized using all available CPU cores,
so each `TextSegment` is embedded in a separate thread.
The parallelization is done by using an `Executor`.
By default, in-process embedding models use a cached thread pool
with the number of threads equal to the number of available processors.
Threads are cached for 1 second.
You can provide a custom instance of the `Executor` when creating a model:
```java
Executor executor = ...;
EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(executor);
```
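For instance (an illustrative sketch, not part of this change; the pool size of 4 is arbitrary):

```java
ExecutorService executor = Executors.newFixedThreadPool(4); // arbitrary pool size

EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(executor);
```

Remember to shut such an executor down (`executor.shutdown()`) when it is no longer needed, so its non-daemon threads do not keep the JVM alive.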
Embedding using GPU is not supported yet.
## Custom models
Many models (e.g., from [Hugging Face](https://huggingface.co/)) can be used,


@ -474,6 +474,7 @@ RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
.queryRouter(...)
.contentAggregator(...)
.contentInjector(...)
.executor(...)
.build();
Assistant assistant = AiServices.builder(Assistant.class)


@ -743,6 +743,20 @@ More details are coming soon.
More details are coming soon.
### Parallelization
When there is only a single `Query` and a single `ContentRetriever`,
`DefaultRetrievalAugmentor` performs query routing and content retrieval in the same thread.
Otherwise, an `Executor` is used to parallelize the processing.
By default, a modified (`keepAliveTime` is 1 second instead of 60 seconds) `Executors.newCachedThreadPool()`
is used, but you can provide a custom `Executor` instance when creating the `DefaultRetrievalAugmentor`:
```java
DefaultRetrievalAugmentor.builder()
...
.executor(executor)
.build();
```
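For example (a sketch, not part of this change), a direct executor runs every task in the calling thread, avoiding extra threads entirely:

```java
Executor sameThreadExecutor = Runnable::run; // executes each task in the submitting thread

RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
        .queryRouter(new DefaultQueryRouter(contentRetriever)) // assumes an existing ContentRetriever
        .executor(sameThreadExecutor)
        .build();
```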
## Examples
- [Easy RAG](https://github.com/langchain4j/langchain4j-examples/blob/main/rag-examples/src/main/java/_1_easy/Easy_RAG_Example.java)


@ -21,16 +21,14 @@ import org.slf4j.LoggerFactory;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.*;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.Utils.isNotNullOrBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static java.util.Collections.*;
import static java.util.concurrent.CompletableFuture.allOf;
import static java.util.concurrent.CompletableFuture.supplyAsync;
import static java.util.concurrent.TimeUnit.SECONDS;
import static java.util.stream.Collectors.*;
/**
@ -90,8 +88,11 @@ import static java.util.stream.Collectors.*;
* Nonetheless, you are encouraged to use one of the advanced ready-to-use implementations or create a custom one.
* <br>
* <br>
* By default, query routing and content retrieval are performed concurrently (for efficiency)
* using {@link Executors#newCachedThreadPool()}, but you can provide a custom {@link Executor}.
* When there is only a single {@link Query} and a single {@link ContentRetriever},
* query routing and content retrieval are performed in the same thread.
* Otherwise, an {@link Executor} is used to parallelize the processing.
* By default, a modified (keepAliveTime is 1 second instead of 60 seconds) {@link Executors#newCachedThreadPool()}
* is used, but you can provide a custom {@link Executor} instance.
*
* @see DefaultQueryTransformer
* @see DefaultQueryRouter
@ -118,7 +119,15 @@ public class DefaultRetrievalAugmentor implements RetrievalAugmentor {
this.queryRouter = ensureNotNull(queryRouter, "queryRouter");
this.contentAggregator = getOrDefault(contentAggregator, DefaultContentAggregator::new);
this.contentInjector = getOrDefault(contentInjector, DefaultContentInjector::new);
this.executor = getOrDefault(executor, Executors::newCachedThreadPool);
this.executor = getOrDefault(executor, DefaultRetrievalAugmentor::createDefaultExecutor);
}
private static ExecutorService createDefaultExecutor() {
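// same configuration as Executors.newCachedThreadPool(), except idle threads are
// kept alive for 1 second instead of 60, so they do not delay JVM shutdown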
return new ThreadPoolExecutor(
0, Integer.MAX_VALUE,
1, SECONDS,
new SynchronousQueue<>()
);
}
/**
@ -142,20 +151,7 @@ public class DefaultRetrievalAugmentor implements RetrievalAugmentor {
Collection<Query> queries = queryTransformer.transform(originalQuery);
logQueries(originalQuery, queries);
Map<Query, CompletableFuture<Collection<List<Content>>>> queryToFutureContents = new ConcurrentHashMap<>();
queries.forEach(query -> {
CompletableFuture<Collection<List<Content>>> futureContents =
supplyAsync(() -> {
Collection<ContentRetriever> retrievers = queryRouter.route(query);
log(query, retrievers);
return retrievers;
},
executor
).thenCompose(retrievers -> retrieveFromAll(retrievers, query));
queryToFutureContents.put(query, futureContents);
});
Map<Query, Collection<List<Content>>> queryToContents = join(queryToFutureContents);
Map<Query, Collection<List<Content>>> queryToContents = process(queries);
List<Content> contents = contentAggregator.aggregate(queryToContents);
log(queryToContents, contents);
@ -169,6 +165,39 @@ public class DefaultRetrievalAugmentor implements RetrievalAugmentor {
.build();
}
private Map<Query, Collection<List<Content>>> process(Collection<Query> queries) {
if (queries.size() == 1) {
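// single query: routing happens directly in the calling thread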
Query query = queries.iterator().next();
Collection<ContentRetriever> retrievers = queryRouter.route(query);
if (retrievers.size() == 1) {
ContentRetriever contentRetriever = retrievers.iterator().next();
List<Content> contents = contentRetriever.retrieve(query);
return singletonMap(query, singletonList(contents));
} else if (retrievers.size() > 1) {
Collection<List<Content>> contents = retrieveFromAll(retrievers, query).join();
return singletonMap(query, contents);
} else {
return emptyMap();
}
} else if (queries.size() > 1) {
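// multiple queries: route and retrieve each query concurrently on the executor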
Map<Query, CompletableFuture<Collection<List<Content>>>> queryToFutureContents = new ConcurrentHashMap<>();
queries.forEach(query -> {
CompletableFuture<Collection<List<Content>>> futureContents =
supplyAsync(() -> {
Collection<ContentRetriever> retrievers = queryRouter.route(query);
log(query, retrievers);
return retrievers;
},
executor
).thenCompose(retrievers -> retrieveFromAll(retrievers, query));
queryToFutureContents.put(query, futureContents);
});
return join(queryToFutureContents);
} else {
return emptyMap();
}
}
private CompletableFuture<Collection<List<Content>>> retrieveFromAll(Collection<ContentRetriever> retrievers,
Query query) {
List<CompletableFuture<List<Content>>> futureContents = retrievers.stream()


@ -10,9 +10,10 @@ import dev.langchain4j.rag.query.Metadata;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.rag.query.router.DefaultQueryRouter;
import dev.langchain4j.rag.query.router.QueryRouter;
import dev.langchain4j.rag.query.transformer.DefaultQueryTransformer;
import dev.langchain4j.rag.query.transformer.QueryTransformer;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import java.util.Collection;
@ -25,6 +26,7 @@ import java.util.stream.Stream;
import static java.util.Arrays.asList;
import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;
import static org.assertj.core.api.Assertions.assertThat;
@ -34,7 +36,7 @@ class DefaultRetrievalAugmentorTest {
@ParameterizedTest
@MethodSource("executors")
void should_augment_user_message(Executor executor) {
void should_augment_user_message__multiple_queries_multiple_retrievers(Executor executor) {
// given
Query query1 = Query.from("query 1");
@ -123,6 +125,152 @@ class DefaultRetrievalAugmentorTest {
verifyNoMoreInteractions(contentInjector);
}
@Test
void should_augment_user_message__single_query_multiple_retrievers() {
// given
QueryTransformer queryTransformer = spy(new DefaultQueryTransformer());
Content content1 = Content.from("content 1");
Content content2 = Content.from("content 2");
ContentRetriever contentRetriever1 = spy(new TestContentRetriever(content1, content2));
Content content3 = Content.from("content 3");
Content content4 = Content.from("content 4");
ContentRetriever contentRetriever2 = spy(new TestContentRetriever(content3, content4));
QueryRouter queryRouter = spy(new DefaultQueryRouter(contentRetriever1, contentRetriever2));
ContentAggregator contentAggregator = spy(new TestContentAggregator());
ContentInjector contentInjector = spy(new TestContentInjector());
Executor executor = spy(new TestExecutor());
RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
.queryTransformer(queryTransformer)
.queryRouter(queryRouter)
.contentAggregator(contentAggregator)
.contentInjector(contentInjector)
.executor(executor)
.build();
UserMessage userMessage = UserMessage.from("query");
Metadata metadata = Metadata.from(userMessage, null, null);
// when
UserMessage augmented = retrievalAugmentor.augment(userMessage, metadata);
// then
assertThat(augmented.singleText()).isEqualTo(
"query\n" +
"content 1\n" +
"content 2\n" +
"content 3\n" +
"content 4"
);
Query query = Query.from("query", metadata);
verify(queryTransformer).transform(query);
verifyNoMoreInteractions(queryTransformer);
verify(queryRouter).route(query);
verifyNoMoreInteractions(queryRouter);
verify(contentRetriever1).retrieve(query);
verifyNoMoreInteractions(contentRetriever1);
verify(contentRetriever2).retrieve(query);
verifyNoMoreInteractions(contentRetriever2);
Map<Query, Collection<List<Content>>> queryToContents = new HashMap<>();
queryToContents.put(query, asList(
asList(content1, content2),
asList(content3, content4)
));
verify(contentAggregator).aggregate(queryToContents);
verifyNoMoreInteractions(contentAggregator);
verify(contentInjector).inject(asList(content1, content2, content3, content4), userMessage);
verify(contentInjector).inject(asList(content1, content2, content3, content4), (ChatMessage) userMessage);
verifyNoMoreInteractions(contentInjector);
verify(executor, times(2)).execute(any());
verifyNoMoreInteractions(executor);
}
private static class TestExecutor implements Executor {
@Override
public void execute(Runnable command) {
command.run();
}
}
@Test
void should_augment_user_message__single_query_single_retriever() {
// given
QueryTransformer queryTransformer = spy(new DefaultQueryTransformer());
Content content1 = Content.from("content 1");
Content content2 = Content.from("content 2");
ContentRetriever contentRetriever = spy(new TestContentRetriever(content1, content2));
QueryRouter queryRouter = spy(new DefaultQueryRouter(contentRetriever));
ContentAggregator contentAggregator = spy(new TestContentAggregator());
ContentInjector contentInjector = spy(new TestContentInjector());
Executor executor = mock(Executor.class);
RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
.queryTransformer(queryTransformer)
.queryRouter(queryRouter)
.contentAggregator(contentAggregator)
.contentInjector(contentInjector)
.executor(executor)
.build();
UserMessage userMessage = UserMessage.from("query");
Metadata metadata = Metadata.from(userMessage, null, null);
// when
UserMessage augmented = retrievalAugmentor.augment(userMessage, metadata);
// then
assertThat(augmented.singleText()).isEqualTo(
"query\n" +
"content 1\n" +
"content 2"
);
Query query = Query.from("query", metadata);
verify(queryTransformer).transform(query);
verifyNoMoreInteractions(queryTransformer);
verify(queryRouter).route(query);
verifyNoMoreInteractions(queryRouter);
verify(contentRetriever).retrieve(query);
verifyNoMoreInteractions(contentRetriever);
Map<Query, Collection<List<Content>>> queryToContents = new HashMap<>();
queryToContents.put(query, singletonList(asList(content1, content2)));
verify(contentAggregator).aggregate(queryToContents);
verifyNoMoreInteractions(contentAggregator);
verify(contentInjector).inject(asList(content1, content2), userMessage);
verify(contentInjector).inject(asList(content1, content2), (ChatMessage) userMessage);
verifyNoMoreInteractions(contentInjector);
verifyNoInteractions(executor);
}
@ParameterizedTest
@MethodSource("executors")
void should_not_augment_when_router_does_not_return_retrievers(Executor executor) {
@ -150,13 +298,15 @@ class DefaultRetrievalAugmentorTest {
verifyNoMoreInteractions(queryRouter);
}
static Stream<Arguments> executors() {
return Stream.<Arguments>builder()
.add(Arguments.of(Executors.newCachedThreadPool()))
.add(Arguments.of(Executors.newFixedThreadPool(1)))
.add(Arguments.of(Executors.newFixedThreadPool(2)))
.add(Arguments.of(Executors.newFixedThreadPool(3)))
.add(Arguments.of(Executors.newFixedThreadPool(4)))
static Stream<Executor> executors() {
return Stream.<Executor>builder()
.add(Executors.newCachedThreadPool())
.add(Executors.newFixedThreadPool(1))
.add(Executors.newFixedThreadPool(2))
.add(Executors.newFixedThreadPool(3))
.add(Executors.newFixedThreadPool(4))
.add(Runnable::run) // same thread executor
.add(null) // to use default Executor in DefaultRetrievalAugmentor
.build();
}