diff --git a/mlir/include/mlir/IR/Threading.h b/mlir/include/mlir/IR/Threading.h
index 0d60e95b54c1..dd99039e298a 100644
--- a/mlir/include/mlir/IR/Threading.h
+++ b/mlir/include/mlir/IR/Threading.h
@@ -41,10 +41,7 @@ LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin,
 
   // If multithreading is disabled or there is a small number of elements,
   // process the elements directly on this thread.
-  // FIXME: ThreadPool should allow work stealing to avoid deadlocks when
-  // scheduling work within a worker thread.
-  if (!context->isMultithreadingEnabled() || numElements <= 1 ||
-      context->getThreadPool().isWorkerThread()) {
+  if (!context->isMultithreadingEnabled() || numElements <= 1) {
     for (; begin != end; ++begin)
       if (failed(func(*begin)))
         return failure();
@@ -70,16 +67,14 @@ LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin,
 
   // Otherwise, process the elements in parallel.
   llvm::ThreadPool &threadPool = context->getThreadPool();
+  llvm::ThreadPoolTaskGroup tasksGroup(threadPool);
   size_t numActions = std::min(numElements, threadPool.getThreadCount());
-  SmallVector<std::shared_future<void>> threadFutures;
-  threadFutures.reserve(numActions - 1);
-  for (unsigned i = 1; i < numActions; ++i)
-    threadFutures.emplace_back(threadPool.async(processFn));
-  processFn();
-
-  // Wait for all of the threads to finish.
-  for (std::shared_future<void> &future : threadFutures)
-    future.wait();
+  for (unsigned i = 0; i < numActions; ++i)
+    tasksGroup.async(processFn);
+  // If the current thread is a worker thread from the pool, then waiting for
+  // the task group allows the current thread to also participate in processing
+  // tasks from the group, which avoids any deadlock/starvation.
+  tasksGroup.wait();
   return failure(processingFailed);
 }
 