Use ThreadPoolTaskGroup in the MLIR threading helper to enable nested parallelism

The LLVM ThreadPool recently gained the concept of a ThreadPoolTaskGroup: a way to "partition" the thread pool into groups of tasks and enable nested parallelism through this grouping at every level of nesting. We use this feature in the MLIR threading abstraction to fix a long-standing TODO and enable nested parallelism.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D124902
parent c5ea8d509c
commit 072e0aabbc
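For context, here is a minimal sketch of the llvm::ThreadPoolTaskGroup API this change builds on. This is not part of the commit; the task bodies are placeholders. A group is created over an existing pool, tasks are enqueued into the group, and calling wait() on the group from a worker thread lets that thread help run the group's tasks instead of blocking.

#include "llvm/Support/ThreadPool.h"

void sketch() {
  llvm::ThreadPool pool; // defaults to hardware concurrency

  llvm::ThreadPoolTaskGroup outer(pool);
  outer.async([&pool] {
    // Nested parallelism: a task may open its own group on the same pool.
    llvm::ThreadPoolTaskGroup inner(pool);
    inner.async([] { /* nested work (placeholder) */ });
    // Waiting from a worker thread makes this thread process the inner
    // group's tasks rather than sitting idle, so nesting cannot deadlock.
    inner.wait();
  });
  outer.wait();
}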
@@ -41,10 +41,7 @@ LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin,
 
   // If multithreading is disabled or there is a small number of elements,
   // process the elements directly on this thread.
-  // FIXME: ThreadPool should allow work stealing to avoid deadlocks when
-  // scheduling work within a worker thread.
-  if (!context->isMultithreadingEnabled() || numElements <= 1 ||
-      context->getThreadPool().isWorkerThread()) {
+  if (!context->isMultithreadingEnabled() || numElements <= 1) {
     for (; begin != end; ++begin)
       if (failed(func(*begin)))
         return failure();
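The removed FIXME referred to the hazard illustrated below (a hypothetical sketch, not part of the patch): before task groups, blocking on a plain future from inside a worker thread could exhaust the pool, which is why the old code bailed out to serial processing via isWorkerThread().

#include "llvm/Support/ThreadPool.h"

void oldHazard(llvm::ThreadPool &pool) {
  auto outer = pool.async([&pool] {
    auto nested = pool.async([] { /* work (placeholder) */ });
    // If every worker is parked on a wait like this one, no thread is
    // left to run `nested` and the pool deadlocks.
    nested.wait();
  });
  outer.wait();
}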
@@ -70,16 +67,14 @@ LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin,
   };
 
   // Otherwise, process the elements in parallel.
   llvm::ThreadPool &threadPool = context->getThreadPool();
+  llvm::ThreadPoolTaskGroup tasksGroup(threadPool);
   size_t numActions = std::min(numElements, threadPool.getThreadCount());
-  SmallVector<std::shared_future<void>> threadFutures;
-  threadFutures.reserve(numActions - 1);
-  for (unsigned i = 1; i < numActions; ++i)
-    threadFutures.emplace_back(threadPool.async(processFn));
-  processFn();
-
-  // Wait for all of the threads to finish.
-  for (std::shared_future<void> &future : threadFutures)
-    future.wait();
+  for (unsigned i = 0; i < numActions; ++i)
+    tasksGroup.async(processFn);
+  // If the current thread is a worker thread from the pool, then waiting for
+  // the task group allows the current thread to also participate in processing
+  // tasks from the group, which avoids any deadlock/starvation.
+  tasksGroup.wait();
   return failure(processingFailed);
 }
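With the task group in place, the helper can now be nested safely. A hedged usage sketch follows; the ModuleWork/OpWork types and the process() helper are illustrative placeholders, not from the patch.

#include "mlir/IR/Threading.h"
#include "llvm/ADT/SmallVector.h"

// Hypothetical work items, not from the patch.
struct OpWork {};
struct ModuleWork { llvm::SmallVector<OpWork> ops; };
void process(OpWork &op);

void processAll(mlir::MLIRContext *ctx,
                llvm::MutableArrayRef<ModuleWork> modules) {
  mlir::parallelForEach(ctx, modules, [&](ModuleWork &m) {
    // Before this change the inner call ran serially when invoked from a
    // pool worker thread; now it opens its own task group and the outer
    // worker helps drain it while waiting.
    mlir::parallelForEach(ctx, m.ops, [&](OpWork &op) { process(op); });
  });
}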