diff --git a/docs/docs/integrations/embedding-models/1-in-process.md b/docs/docs/integrations/embedding-models/1-in-process.md
index b7e35ce95..b8827d56d 100644
--- a/docs/docs/integrations/embedding-models/1-in-process.md
+++ b/docs/docs/integrations/embedding-models/1-in-process.md
@@ -8,9 +8,6 @@ LangChain4j provides a few popular local embedding models packaged as maven depe
They are powered by [ONNX runtime](https://onnxruntime.ai/docs/get-started/with-java.html)
and are running in the same java process.
-By default, embedding is parallelized using all available CPU cores.
-Embedding using GPU is not supported yet.
-
Each model is provided in 2 flavours: original and quantized (has a `-q` suffix in maven artifact name and `Quantized` in the class name).
For example:
@@ -44,6 +41,24 @@ Embedding embedding = response.content();
The complete list of all embedding models can be found [here](https://github.com/langchain4j/langchain4j-embeddings).
+## Parallelization
+
+By default, the embedding process is parallelized using all available CPU cores,
+so each `TextSegment` is embedded in a separate thread.
+
+The parallelization is done by using an `Executor`.
+By default, in-process embedding models use a cached thread pool
+with the number of threads equal to the number of available processors.
+Threads are cached for 1 second.
+
+You can provide a custom instance of the `Executor` when creating a model:
+```java
+Executor = ...;
+EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(executor);
+```
+
+Embedding using GPU is not supported yet.
+
## Custom models
Many models (e.g., from [Hugging Face](https://huggingface.co/)) can be used,
diff --git a/docs/docs/tutorials/5-ai-services.md b/docs/docs/tutorials/5-ai-services.md
index 1e0a1651c..d1177aab8 100644
--- a/docs/docs/tutorials/5-ai-services.md
+++ b/docs/docs/tutorials/5-ai-services.md
@@ -474,6 +474,7 @@ RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
.queryRouter(...)
.contentAggregator(...)
.contentInjector(...)
+ .executor(...)
.build();
Assistant assistant = AiServices.builder(Assistant.class)
diff --git a/docs/docs/tutorials/7-rag.md b/docs/docs/tutorials/7-rag.md
index ff000204d..89eb35d02 100644
--- a/docs/docs/tutorials/7-rag.md
+++ b/docs/docs/tutorials/7-rag.md
@@ -743,6 +743,20 @@ More details are coming soon.
More details are coming soon.
+### Parallelization
+
+When there is only a single `Query` and a single `ContentRetriever`,
+`DefaultRetrievalAugmentor` performs query routing and content retrieval in the same thread.
+Otherwise, an `Executor` is used to parallelize the processing.
+By default, a modified (`keepAliveTime` is 1 second instead of 60 seconds) `Executors.newCachedThreadPool()`
+is used, but you can provide a custom `Executor` instance when creating the `DefaultRetrievalAugmentor`:
+```java
+DefaultRetrievalAugmentor.builder()
+ ...
+ .executor(executor)
+ .build;
+```
+
## Examples
- [Easy RAG](https://github.com/langchain4j/langchain4j-examples/blob/main/rag-examples/src/main/java/_1_easy/Easy_RAG_Example.java)
diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/DefaultRetrievalAugmentor.java b/langchain4j-core/src/main/java/dev/langchain4j/rag/DefaultRetrievalAugmentor.java
index da58e57d1..db228db14 100644
--- a/langchain4j-core/src/main/java/dev/langchain4j/rag/DefaultRetrievalAugmentor.java
+++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/DefaultRetrievalAugmentor.java
@@ -21,16 +21,14 @@ import org.slf4j.LoggerFactory;
import java.util.Collection;
import java.util.List;
import java.util.Map;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.Executor;
-import java.util.concurrent.Executors;
+import java.util.concurrent.*;
import static dev.langchain4j.internal.Utils.getOrDefault;
-import static dev.langchain4j.internal.Utils.isNotNullOrBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
+import static java.util.Collections.*;
import static java.util.concurrent.CompletableFuture.allOf;
import static java.util.concurrent.CompletableFuture.supplyAsync;
+import static java.util.concurrent.TimeUnit.SECONDS;
import static java.util.stream.Collectors.*;
/**
@@ -90,8 +88,11 @@ import static java.util.stream.Collectors.*;
* Nonetheless, you are encouraged to use one of the advanced ready-to-use implementations or create a custom one.
*
*
- * By default, query routing and content retrieval are performed concurrently (for efficiency)
- * using {@link Executors#newCachedThreadPool()}, but you can provide a custom {@link Executor}.
+ * When there is only a single {@link Query} and a single {@link ContentRetriever},
+ * query routing and content retrieval are performed in the same thread.
+ * Otherwise, an {@link Executor} is used to parallelize the processing.
+ * By default, a modified (keepAliveTime is 1 second instead of 60 seconds) {@link Executors#newCachedThreadPool()}
+ * is used, but you can provide a custom {@link Executor} instance.
*
* @see DefaultQueryTransformer
* @see DefaultQueryRouter
@@ -118,7 +119,15 @@ public class DefaultRetrievalAugmentor implements RetrievalAugmentor {
this.queryRouter = ensureNotNull(queryRouter, "queryRouter");
this.contentAggregator = getOrDefault(contentAggregator, DefaultContentAggregator::new);
this.contentInjector = getOrDefault(contentInjector, DefaultContentInjector::new);
- this.executor = getOrDefault(executor, Executors::newCachedThreadPool);
+ this.executor = getOrDefault(executor, DefaultRetrievalAugmentor::createDefaultExecutor);
+ }
+
+ private static ExecutorService createDefaultExecutor() {
+ return new ThreadPoolExecutor(
+ 0, Integer.MAX_VALUE,
+ 1, SECONDS,
+ new SynchronousQueue<>()
+ );
}
/**
@@ -142,20 +151,7 @@ public class DefaultRetrievalAugmentor implements RetrievalAugmentor {
Collection queries = queryTransformer.transform(originalQuery);
logQueries(originalQuery, queries);
- Map>>> queryToFutureContents = new ConcurrentHashMap<>();
- queries.forEach(query -> {
- CompletableFuture>> futureContents =
- supplyAsync(() -> {
- Collection retrievers = queryRouter.route(query);
- log(query, retrievers);
- return retrievers;
- },
- executor
- ).thenCompose(retrievers -> retrieveFromAll(retrievers, query));
- queryToFutureContents.put(query, futureContents);
- });
-
- Map>> queryToContents = join(queryToFutureContents);
+ Map>> queryToContents = process(queries);
List contents = contentAggregator.aggregate(queryToContents);
log(queryToContents, contents);
@@ -169,6 +165,39 @@ public class DefaultRetrievalAugmentor implements RetrievalAugmentor {
.build();
}
+ private Map>> process(Collection queries) {
+ if (queries.size() == 1) {
+ Query query = queries.iterator().next();
+ Collection retrievers = queryRouter.route(query);
+ if (retrievers.size() == 1) {
+ ContentRetriever contentRetriever = retrievers.iterator().next();
+ List contents = contentRetriever.retrieve(query);
+ return singletonMap(query, singletonList(contents));
+ } else if (retrievers.size() > 1) {
+ Collection> contents = retrieveFromAll(retrievers, query).join();
+ return singletonMap(query, contents);
+ } else {
+ return emptyMap();
+ }
+ } else if (queries.size() > 1) {
+ Map>>> queryToFutureContents = new ConcurrentHashMap<>();
+ queries.forEach(query -> {
+ CompletableFuture>> futureContents =
+ supplyAsync(() -> {
+ Collection retrievers = queryRouter.route(query);
+ log(query, retrievers);
+ return retrievers;
+ },
+ executor
+ ).thenCompose(retrievers -> retrieveFromAll(retrievers, query));
+ queryToFutureContents.put(query, futureContents);
+ });
+ return join(queryToFutureContents);
+ } else {
+ return emptyMap();
+ }
+ }
+
private CompletableFuture>> retrieveFromAll(Collection retrievers,
Query query) {
List>> futureContents = retrievers.stream()
diff --git a/langchain4j-core/src/test/java/dev/langchain4j/rag/DefaultRetrievalAugmentorTest.java b/langchain4j-core/src/test/java/dev/langchain4j/rag/DefaultRetrievalAugmentorTest.java
index f9a673740..ca622a4f4 100644
--- a/langchain4j-core/src/test/java/dev/langchain4j/rag/DefaultRetrievalAugmentorTest.java
+++ b/langchain4j-core/src/test/java/dev/langchain4j/rag/DefaultRetrievalAugmentorTest.java
@@ -10,9 +10,10 @@ import dev.langchain4j.rag.query.Metadata;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.rag.query.router.DefaultQueryRouter;
import dev.langchain4j.rag.query.router.QueryRouter;
+import dev.langchain4j.rag.query.transformer.DefaultQueryTransformer;
import dev.langchain4j.rag.query.transformer.QueryTransformer;
+import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
-import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import java.util.Collection;
@@ -25,6 +26,7 @@ import java.util.stream.Stream;
import static java.util.Arrays.asList;
import static java.util.Collections.emptyList;
+import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;
import static org.assertj.core.api.Assertions.assertThat;
@@ -34,7 +36,7 @@ class DefaultRetrievalAugmentorTest {
@ParameterizedTest
@MethodSource("executors")
- void should_augment_user_message(Executor executor) {
+ void should_augment_user_message__multiple_queries_multiple_retrievers(Executor executor) {
// given
Query query1 = Query.from("query 1");
@@ -123,6 +125,152 @@ class DefaultRetrievalAugmentorTest {
verifyNoMoreInteractions(contentInjector);
}
+ @Test
+ void should_augment_user_message__single_query_multiple_retrievers() {
+
+ // given
+ QueryTransformer queryTransformer = spy(new DefaultQueryTransformer());
+
+ Content content1 = Content.from("content 1");
+ Content content2 = Content.from("content 2");
+ ContentRetriever contentRetriever1 = spy(new TestContentRetriever(content1, content2));
+
+ Content content3 = Content.from("content 3");
+ Content content4 = Content.from("content 4");
+ ContentRetriever contentRetriever2 = spy(new TestContentRetriever(content3, content4));
+
+ QueryRouter queryRouter = spy(new DefaultQueryRouter(contentRetriever1, contentRetriever2));
+
+ ContentAggregator contentAggregator = spy(new TestContentAggregator());
+
+ ContentInjector contentInjector = spy(new TestContentInjector());
+
+ Executor executor = spy(new TestExecutor());
+
+ RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
+ .queryTransformer(queryTransformer)
+ .queryRouter(queryRouter)
+ .contentAggregator(contentAggregator)
+ .contentInjector(contentInjector)
+ .executor(executor)
+ .build();
+
+ UserMessage userMessage = UserMessage.from("query");
+
+ Metadata metadata = Metadata.from(userMessage, null, null);
+
+ // when
+ UserMessage augmented = retrievalAugmentor.augment(userMessage, metadata);
+
+ // then
+ assertThat(augmented.singleText()).isEqualTo(
+ "query\n" +
+ "content 1\n" +
+ "content 2\n" +
+ "content 3\n" +
+ "content 4"
+ );
+
+ Query query = Query.from("query", metadata);
+ verify(queryTransformer).transform(query);
+ verifyNoMoreInteractions(queryTransformer);
+
+ verify(queryRouter).route(query);
+ verifyNoMoreInteractions(queryRouter);
+
+ verify(contentRetriever1).retrieve(query);
+ verifyNoMoreInteractions(contentRetriever1);
+
+ verify(contentRetriever2).retrieve(query);
+ verifyNoMoreInteractions(contentRetriever2);
+
+ Map>> queryToContents = new HashMap<>();
+ queryToContents.put(query, asList(
+ asList(content1, content2),
+ asList(content3, content4)
+
+ ));
+ verify(contentAggregator).aggregate(queryToContents);
+ verifyNoMoreInteractions(contentAggregator);
+
+ verify(contentInjector).inject(asList(content1, content2, content3, content4), userMessage);
+ verify(contentInjector).inject(asList(content1, content2, content3, content4), (ChatMessage) userMessage);
+ verifyNoMoreInteractions(contentInjector);
+
+ verify(executor, times(2)).execute(any());
+ verifyNoMoreInteractions(executor);
+ }
+
+ private static class TestExecutor implements Executor {
+
+ @Override
+ public void execute(Runnable command) {
+ command.run();
+ }
+ }
+
+ @Test
+ void should_augment_user_message__single_query_single_retriever() {
+
+ // given
+ QueryTransformer queryTransformer = spy(new DefaultQueryTransformer());
+
+ Content content1 = Content.from("content 1");
+ Content content2 = Content.from("content 2");
+ ContentRetriever contentRetriever = spy(new TestContentRetriever(content1, content2));
+
+ QueryRouter queryRouter = spy(new DefaultQueryRouter(contentRetriever));
+
+ ContentAggregator contentAggregator = spy(new TestContentAggregator());
+
+ ContentInjector contentInjector = spy(new TestContentInjector());
+
+ Executor executor = mock(Executor.class);
+
+ RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
+ .queryTransformer(queryTransformer)
+ .queryRouter(queryRouter)
+ .contentAggregator(contentAggregator)
+ .contentInjector(contentInjector)
+ .executor(executor)
+ .build();
+
+ UserMessage userMessage = UserMessage.from("query");
+
+ Metadata metadata = Metadata.from(userMessage, null, null);
+
+ // when
+ UserMessage augmented = retrievalAugmentor.augment(userMessage, metadata);
+
+ // then
+ assertThat(augmented.singleText()).isEqualTo(
+ "query\n" +
+ "content 1\n" +
+ "content 2"
+ );
+
+ Query query = Query.from("query", metadata);
+ verify(queryTransformer).transform(query);
+ verifyNoMoreInteractions(queryTransformer);
+
+ verify(queryRouter).route(query);
+ verifyNoMoreInteractions(queryRouter);
+
+ verify(contentRetriever).retrieve(query);
+ verifyNoMoreInteractions(contentRetriever);
+
+ Map>> queryToContents = new HashMap<>();
+ queryToContents.put(query, singletonList(asList(content1, content2)));
+ verify(contentAggregator).aggregate(queryToContents);
+ verifyNoMoreInteractions(contentAggregator);
+
+ verify(contentInjector).inject(asList(content1, content2), userMessage);
+ verify(contentInjector).inject(asList(content1, content2), (ChatMessage) userMessage);
+ verifyNoMoreInteractions(contentInjector);
+
+ verifyNoInteractions(executor);
+ }
+
@ParameterizedTest
@MethodSource("executors")
void should_not_augment_when_router_does_not_return_retrievers(Executor executor) {
@@ -150,13 +298,15 @@ class DefaultRetrievalAugmentorTest {
verifyNoMoreInteractions(queryRouter);
}
- static Stream executors() {
- return Stream.builder()
- .add(Arguments.of(Executors.newCachedThreadPool()))
- .add(Arguments.of(Executors.newFixedThreadPool(1)))
- .add(Arguments.of(Executors.newFixedThreadPool(2)))
- .add(Arguments.of(Executors.newFixedThreadPool(3)))
- .add(Arguments.of(Executors.newFixedThreadPool(4)))
+ static Stream executors() {
+ return Stream.builder()
+ .add(Executors.newCachedThreadPool())
+ .add(Executors.newFixedThreadPool(1))
+ .add(Executors.newFixedThreadPool(2))
+ .add(Executors.newFixedThreadPool(3))
+ .add(Executors.newFixedThreadPool(4))
+ .add(Runnable::run) // same thread executor
+ .add(null) // to use default Executor in DefaultRetrievalAugmentor
.build();
}