Enable the use of ThreadPoolTaskGroup in MLIR threading helper to enable nested parallelism

The LLVM ThreadPool recently got the addition of the concept of
ThreadPoolTaskGroup: this is a way to "partition" the threadpool
into a group of tasks and enable nested parallelism through this
grouping at every level of nesting.
We make use of this feature in MLIR threading abstraction to fix a long
lasting TODO and enable nested parallelism.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D124902
This commit is contained in:
Mehdi Amini 2022-05-06 19:38:49 +00:00
parent c5ea8d509c
commit 072e0aabbc
1 changed files with 8 additions and 13 deletions

View File

@ -41,10 +41,7 @@ LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin,
// If multithreading is disabled or there is a small number of elements, // If multithreading is disabled or there is a small number of elements,
// process the elements directly on this thread. // process the elements directly on this thread.
// FIXME: ThreadPool should allow work stealing to avoid deadlocks when if (!context->isMultithreadingEnabled() || numElements <= 1) {
// scheduling work within a worker thread.
if (!context->isMultithreadingEnabled() || numElements <= 1 ||
context->getThreadPool().isWorkerThread()) {
for (; begin != end; ++begin) for (; begin != end; ++begin)
if (failed(func(*begin))) if (failed(func(*begin)))
return failure(); return failure();
@ -70,16 +67,14 @@ LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin,
// Otherwise, process the elements in parallel. // Otherwise, process the elements in parallel.
llvm::ThreadPool &threadPool = context->getThreadPool(); llvm::ThreadPool &threadPool = context->getThreadPool();
llvm::ThreadPoolTaskGroup tasksGroup(threadPool);
size_t numActions = std::min(numElements, threadPool.getThreadCount()); size_t numActions = std::min(numElements, threadPool.getThreadCount());
SmallVector<std::shared_future<void>> threadFutures; for (unsigned i = 0; i < numActions; ++i)
threadFutures.reserve(numActions - 1); tasksGroup.async(processFn);
for (unsigned i = 1; i < numActions; ++i) // If the current thread is a worker thread from the pool, then waiting for
threadFutures.emplace_back(threadPool.async(processFn)); // the task group allows the current thread to also participate in processing
processFn(); // tasks from the group, which avoid any deadlock/starvation.
tasksGroup.wait();
// Wait for all of the threads to finish.
for (std::shared_future<void> &future : threadFutures)
future.wait();
return failure(processingFailed); return failure(processingFailed);
} }