diff --git a/docs/docs/integrations/embedding-models/1-in-process.md b/docs/docs/integrations/embedding-models/1-in-process.md index b7e35ce95..b8827d56d 100644 --- a/docs/docs/integrations/embedding-models/1-in-process.md +++ b/docs/docs/integrations/embedding-models/1-in-process.md @@ -8,9 +8,6 @@ LangChain4j provides a few popular local embedding models packaged as maven depe They are powered by [ONNX runtime](https://onnxruntime.ai/docs/get-started/with-java.html) and are running in the same java process. -By default, embedding is parallelized using all available CPU cores. -Embedding using GPU is not supported yet. - Each model is provided in 2 flavours: original and quantized (has a `-q` suffix in maven artifact name and `Quantized` in the class name). For example: @@ -44,6 +41,24 @@ Embedding embedding = response.content(); The complete list of all embedding models can be found [here](https://github.com/langchain4j/langchain4j-embeddings). +## Parallelization + +By default, the embedding process is parallelized using all available CPU cores, +so each `TextSegment` is embedded in a separate thread. + +The parallelization is done by using an `Executor`. +By default, in-process embedding models use a cached thread pool +with the number of threads equal to the number of available processors. +Threads are cached for 1 second. + +You can provide a custom instance of the `Executor` when creating a model: +```java +Executor = ...; +EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(executor); +``` + +Embedding using GPU is not supported yet. + ## Custom models Many models (e.g., from [Hugging Face](https://huggingface.co/)) can be used, diff --git a/docs/docs/tutorials/5-ai-services.md b/docs/docs/tutorials/5-ai-services.md index 1e0a1651c..d1177aab8 100644 --- a/docs/docs/tutorials/5-ai-services.md +++ b/docs/docs/tutorials/5-ai-services.md @@ -474,6 +474,7 @@ RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder() .queryRouter(...) .contentAggregator(...) .contentInjector(...) + .executor(...) .build(); Assistant assistant = AiServices.builder(Assistant.class) diff --git a/docs/docs/tutorials/7-rag.md b/docs/docs/tutorials/7-rag.md index ff000204d..89eb35d02 100644 --- a/docs/docs/tutorials/7-rag.md +++ b/docs/docs/tutorials/7-rag.md @@ -743,6 +743,20 @@ More details are coming soon. More details are coming soon. +### Parallelization + +When there is only a single `Query` and a single `ContentRetriever`, +`DefaultRetrievalAugmentor` performs query routing and content retrieval in the same thread. +Otherwise, an `Executor` is used to parallelize the processing. +By default, a modified (`keepAliveTime` is 1 second instead of 60 seconds) `Executors.newCachedThreadPool()` +is used, but you can provide a custom `Executor` instance when creating the `DefaultRetrievalAugmentor`: +```java +DefaultRetrievalAugmentor.builder() + ... + .executor(executor) + .build; +``` + ## Examples - [Easy RAG](https://github.com/langchain4j/langchain4j-examples/blob/main/rag-examples/src/main/java/_1_easy/Easy_RAG_Example.java) diff --git a/langchain4j-core/src/main/java/dev/langchain4j/rag/DefaultRetrievalAugmentor.java b/langchain4j-core/src/main/java/dev/langchain4j/rag/DefaultRetrievalAugmentor.java index da58e57d1..db228db14 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/rag/DefaultRetrievalAugmentor.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/rag/DefaultRetrievalAugmentor.java @@ -21,16 +21,14 @@ import org.slf4j.LoggerFactory; import java.util.Collection; import java.util.List; import java.util.Map; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.Executor; -import java.util.concurrent.Executors; +import java.util.concurrent.*; import static dev.langchain4j.internal.Utils.getOrDefault; -import static dev.langchain4j.internal.Utils.isNotNullOrBlank; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; +import static java.util.Collections.*; import static java.util.concurrent.CompletableFuture.allOf; import static java.util.concurrent.CompletableFuture.supplyAsync; +import static java.util.concurrent.TimeUnit.SECONDS; import static java.util.stream.Collectors.*; /** @@ -90,8 +88,11 @@ import static java.util.stream.Collectors.*; * Nonetheless, you are encouraged to use one of the advanced ready-to-use implementations or create a custom one. *
*
- * By default, query routing and content retrieval are performed concurrently (for efficiency) - * using {@link Executors#newCachedThreadPool()}, but you can provide a custom {@link Executor}. + * When there is only a single {@link Query} and a single {@link ContentRetriever}, + * query routing and content retrieval are performed in the same thread. + * Otherwise, an {@link Executor} is used to parallelize the processing. + * By default, a modified (keepAliveTime is 1 second instead of 60 seconds) {@link Executors#newCachedThreadPool()} + * is used, but you can provide a custom {@link Executor} instance. * * @see DefaultQueryTransformer * @see DefaultQueryRouter @@ -118,7 +119,15 @@ public class DefaultRetrievalAugmentor implements RetrievalAugmentor { this.queryRouter = ensureNotNull(queryRouter, "queryRouter"); this.contentAggregator = getOrDefault(contentAggregator, DefaultContentAggregator::new); this.contentInjector = getOrDefault(contentInjector, DefaultContentInjector::new); - this.executor = getOrDefault(executor, Executors::newCachedThreadPool); + this.executor = getOrDefault(executor, DefaultRetrievalAugmentor::createDefaultExecutor); + } + + private static ExecutorService createDefaultExecutor() { + return new ThreadPoolExecutor( + 0, Integer.MAX_VALUE, + 1, SECONDS, + new SynchronousQueue<>() + ); } /** @@ -142,20 +151,7 @@ public class DefaultRetrievalAugmentor implements RetrievalAugmentor { Collection queries = queryTransformer.transform(originalQuery); logQueries(originalQuery, queries); - Map>>> queryToFutureContents = new ConcurrentHashMap<>(); - queries.forEach(query -> { - CompletableFuture>> futureContents = - supplyAsync(() -> { - Collection retrievers = queryRouter.route(query); - log(query, retrievers); - return retrievers; - }, - executor - ).thenCompose(retrievers -> retrieveFromAll(retrievers, query)); - queryToFutureContents.put(query, futureContents); - }); - - Map>> queryToContents = join(queryToFutureContents); + Map>> queryToContents = process(queries); List contents = contentAggregator.aggregate(queryToContents); log(queryToContents, contents); @@ -169,6 +165,39 @@ public class DefaultRetrievalAugmentor implements RetrievalAugmentor { .build(); } + private Map>> process(Collection queries) { + if (queries.size() == 1) { + Query query = queries.iterator().next(); + Collection retrievers = queryRouter.route(query); + if (retrievers.size() == 1) { + ContentRetriever contentRetriever = retrievers.iterator().next(); + List contents = contentRetriever.retrieve(query); + return singletonMap(query, singletonList(contents)); + } else if (retrievers.size() > 1) { + Collection> contents = retrieveFromAll(retrievers, query).join(); + return singletonMap(query, contents); + } else { + return emptyMap(); + } + } else if (queries.size() > 1) { + Map>>> queryToFutureContents = new ConcurrentHashMap<>(); + queries.forEach(query -> { + CompletableFuture>> futureContents = + supplyAsync(() -> { + Collection retrievers = queryRouter.route(query); + log(query, retrievers); + return retrievers; + }, + executor + ).thenCompose(retrievers -> retrieveFromAll(retrievers, query)); + queryToFutureContents.put(query, futureContents); + }); + return join(queryToFutureContents); + } else { + return emptyMap(); + } + } + private CompletableFuture>> retrieveFromAll(Collection retrievers, Query query) { List>> futureContents = retrievers.stream() diff --git a/langchain4j-core/src/test/java/dev/langchain4j/rag/DefaultRetrievalAugmentorTest.java b/langchain4j-core/src/test/java/dev/langchain4j/rag/DefaultRetrievalAugmentorTest.java index f9a673740..ca622a4f4 100644 --- a/langchain4j-core/src/test/java/dev/langchain4j/rag/DefaultRetrievalAugmentorTest.java +++ b/langchain4j-core/src/test/java/dev/langchain4j/rag/DefaultRetrievalAugmentorTest.java @@ -10,9 +10,10 @@ import dev.langchain4j.rag.query.Metadata; import dev.langchain4j.rag.query.Query; import dev.langchain4j.rag.query.router.DefaultQueryRouter; import dev.langchain4j.rag.query.router.QueryRouter; +import dev.langchain4j.rag.query.transformer.DefaultQueryTransformer; import dev.langchain4j.rag.query.transformer.QueryTransformer; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import java.util.Collection; @@ -25,6 +26,7 @@ import java.util.stream.Stream; import static java.util.Arrays.asList; import static java.util.Collections.emptyList; +import static java.util.Collections.singletonList; import static java.util.stream.Collectors.joining; import static java.util.stream.Collectors.toList; import static org.assertj.core.api.Assertions.assertThat; @@ -34,7 +36,7 @@ class DefaultRetrievalAugmentorTest { @ParameterizedTest @MethodSource("executors") - void should_augment_user_message(Executor executor) { + void should_augment_user_message__multiple_queries_multiple_retrievers(Executor executor) { // given Query query1 = Query.from("query 1"); @@ -123,6 +125,152 @@ class DefaultRetrievalAugmentorTest { verifyNoMoreInteractions(contentInjector); } + @Test + void should_augment_user_message__single_query_multiple_retrievers() { + + // given + QueryTransformer queryTransformer = spy(new DefaultQueryTransformer()); + + Content content1 = Content.from("content 1"); + Content content2 = Content.from("content 2"); + ContentRetriever contentRetriever1 = spy(new TestContentRetriever(content1, content2)); + + Content content3 = Content.from("content 3"); + Content content4 = Content.from("content 4"); + ContentRetriever contentRetriever2 = spy(new TestContentRetriever(content3, content4)); + + QueryRouter queryRouter = spy(new DefaultQueryRouter(contentRetriever1, contentRetriever2)); + + ContentAggregator contentAggregator = spy(new TestContentAggregator()); + + ContentInjector contentInjector = spy(new TestContentInjector()); + + Executor executor = spy(new TestExecutor()); + + RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder() + .queryTransformer(queryTransformer) + .queryRouter(queryRouter) + .contentAggregator(contentAggregator) + .contentInjector(contentInjector) + .executor(executor) + .build(); + + UserMessage userMessage = UserMessage.from("query"); + + Metadata metadata = Metadata.from(userMessage, null, null); + + // when + UserMessage augmented = retrievalAugmentor.augment(userMessage, metadata); + + // then + assertThat(augmented.singleText()).isEqualTo( + "query\n" + + "content 1\n" + + "content 2\n" + + "content 3\n" + + "content 4" + ); + + Query query = Query.from("query", metadata); + verify(queryTransformer).transform(query); + verifyNoMoreInteractions(queryTransformer); + + verify(queryRouter).route(query); + verifyNoMoreInteractions(queryRouter); + + verify(contentRetriever1).retrieve(query); + verifyNoMoreInteractions(contentRetriever1); + + verify(contentRetriever2).retrieve(query); + verifyNoMoreInteractions(contentRetriever2); + + Map>> queryToContents = new HashMap<>(); + queryToContents.put(query, asList( + asList(content1, content2), + asList(content3, content4) + + )); + verify(contentAggregator).aggregate(queryToContents); + verifyNoMoreInteractions(contentAggregator); + + verify(contentInjector).inject(asList(content1, content2, content3, content4), userMessage); + verify(contentInjector).inject(asList(content1, content2, content3, content4), (ChatMessage) userMessage); + verifyNoMoreInteractions(contentInjector); + + verify(executor, times(2)).execute(any()); + verifyNoMoreInteractions(executor); + } + + private static class TestExecutor implements Executor { + + @Override + public void execute(Runnable command) { + command.run(); + } + } + + @Test + void should_augment_user_message__single_query_single_retriever() { + + // given + QueryTransformer queryTransformer = spy(new DefaultQueryTransformer()); + + Content content1 = Content.from("content 1"); + Content content2 = Content.from("content 2"); + ContentRetriever contentRetriever = spy(new TestContentRetriever(content1, content2)); + + QueryRouter queryRouter = spy(new DefaultQueryRouter(contentRetriever)); + + ContentAggregator contentAggregator = spy(new TestContentAggregator()); + + ContentInjector contentInjector = spy(new TestContentInjector()); + + Executor executor = mock(Executor.class); + + RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder() + .queryTransformer(queryTransformer) + .queryRouter(queryRouter) + .contentAggregator(contentAggregator) + .contentInjector(contentInjector) + .executor(executor) + .build(); + + UserMessage userMessage = UserMessage.from("query"); + + Metadata metadata = Metadata.from(userMessage, null, null); + + // when + UserMessage augmented = retrievalAugmentor.augment(userMessage, metadata); + + // then + assertThat(augmented.singleText()).isEqualTo( + "query\n" + + "content 1\n" + + "content 2" + ); + + Query query = Query.from("query", metadata); + verify(queryTransformer).transform(query); + verifyNoMoreInteractions(queryTransformer); + + verify(queryRouter).route(query); + verifyNoMoreInteractions(queryRouter); + + verify(contentRetriever).retrieve(query); + verifyNoMoreInteractions(contentRetriever); + + Map>> queryToContents = new HashMap<>(); + queryToContents.put(query, singletonList(asList(content1, content2))); + verify(contentAggregator).aggregate(queryToContents); + verifyNoMoreInteractions(contentAggregator); + + verify(contentInjector).inject(asList(content1, content2), userMessage); + verify(contentInjector).inject(asList(content1, content2), (ChatMessage) userMessage); + verifyNoMoreInteractions(contentInjector); + + verifyNoInteractions(executor); + } + @ParameterizedTest @MethodSource("executors") void should_not_augment_when_router_does_not_return_retrievers(Executor executor) { @@ -150,13 +298,15 @@ class DefaultRetrievalAugmentorTest { verifyNoMoreInteractions(queryRouter); } - static Stream executors() { - return Stream.builder() - .add(Arguments.of(Executors.newCachedThreadPool())) - .add(Arguments.of(Executors.newFixedThreadPool(1))) - .add(Arguments.of(Executors.newFixedThreadPool(2))) - .add(Arguments.of(Executors.newFixedThreadPool(3))) - .add(Arguments.of(Executors.newFixedThreadPool(4))) + static Stream executors() { + return Stream.builder() + .add(Executors.newCachedThreadPool()) + .add(Executors.newFixedThreadPool(1)) + .add(Executors.newFixedThreadPool(2)) + .add(Executors.newFixedThreadPool(3)) + .add(Executors.newFixedThreadPool(4)) + .add(Runnable::run) // same thread executor + .add(null) // to use default Executor in DefaultRetrievalAugmentor .build(); }