Enable the use of ThreadPoolTaskGroup in MLIR threading helper to enable nested parallelism

The LLVM ThreadPool recently got the addition of the concept of
ThreadPoolTaskGroup: this is a way to "partition" the threadpool
into a group of tasks and enable nested parallelism through this
grouping at every level of nesting.
We make use of this feature in MLIR threading abstraction to fix a long
lasting TODO and enable nested parallelism.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D124902
This commit is contained in:
Mehdi Amini 2022-05-06 19:38:49 +00:00
parent c5ea8d509c
commit 072e0aabbc
1 changed files with 8 additions and 13 deletions

View File

@ -41,10 +41,7 @@ LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin,
// If multithreading is disabled or there is a small number of elements,
// process the elements directly on this thread.
// FIXME: ThreadPool should allow work stealing to avoid deadlocks when
// scheduling work within a worker thread.
if (!context->isMultithreadingEnabled() || numElements <= 1 ||
context->getThreadPool().isWorkerThread()) {
if (!context->isMultithreadingEnabled() || numElements <= 1) {
for (; begin != end; ++begin)
if (failed(func(*begin)))
return failure();
@ -70,16 +67,14 @@ LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin,
// Otherwise, process the elements in parallel.
llvm::ThreadPool &threadPool = context->getThreadPool();
llvm::ThreadPoolTaskGroup tasksGroup(threadPool);
size_t numActions = std::min(numElements, threadPool.getThreadCount());
SmallVector<std::shared_future<void>> threadFutures;
threadFutures.reserve(numActions - 1);
for (unsigned i = 1; i < numActions; ++i)
threadFutures.emplace_back(threadPool.async(processFn));
processFn();
// Wait for all of the threads to finish.
for (std::shared_future<void> &future : threadFutures)
future.wait();
for (unsigned i = 0; i < numActions; ++i)
tasksGroup.async(processFn);
// If the current thread is a worker thread from the pool, then waiting for
// the task group allows the current thread to also participate in processing
// tasks from the group, which avoid any deadlock/starvation.
tasksGroup.wait();
return failure(processingFailed);
}