ThreadPool: grow the pool only as needed

On my 96-core cloudtop 'machine', it seems unnecessary to always start
96 threads upfront... particularly as the ThreadPool is created even
with -mlir-disable-threading. Things like the resulting spew in GDB and
the obfuscated output of `(gdb) info threads` are my motivation here,
but it probably also doesn't hurt for at least some efficiency metrics to
avoid creating many threads upfront.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D115019
This commit is contained in:
Benoit Jacob 2021-12-03 21:40:28 +00:00 committed by Mehdi Amini
parent 93a20ecee4
commit 728b982bb2
2 changed files with 69 additions and 46 deletions

View File

@ -40,7 +40,8 @@ public:
/// execution resources (threads, cores, CPUs)
/// Defaults to using the maximum execution resources in the system, but
/// accounting for the affinity mask.
ThreadPool(ThreadPoolStrategy S = hardware_concurrency());
ThreadPool(ThreadPoolStrategy S = hardware_concurrency())
: Strategy(S), MaxThreadCount(S.compute_thread_count()) {}
/// Blocking destructor: the pool will wait for all the threads to complete.
~ThreadPool();
@ -65,7 +66,10 @@ public:
/// It is an error to try to add new tasks while blocking on this call.
void wait();
unsigned getThreadCount() const { return ThreadCount; }
// TODO: misleading legacy name warning!
// Returns the maximum number of worker threads in the pool, not the current
// number of threads!
unsigned getThreadCount() const { return MaxThreadCount; }
/// Returns true if the current thread is a worker thread of this thread pool.
bool isWorkerThread() const;
@ -115,6 +119,7 @@ private:
// Don't allow enqueueing after disabling the pool
assert(EnableFlag && "Queuing a thread during ThreadPool destruction");
Tasks.push(std::move(R.first));
grow();
}
QueueCondition.notify_one();
return R.second.share();
@ -130,6 +135,21 @@ private:
#endif
}
#if LLVM_ENABLE_THREADS
// Maybe create a new thread and add it to Threads.
//
// Requirements:
// * this->QueueLock should be owned by the calling thread prior to
// calling this function. It will neither lock it nor unlock it.
// Calling this function without owning QueueLock would result in data
// races as this function reads Tasks and ActiveThreads.
// * this->Tasks should be populated with any pending tasks. This function
// uses Tasks.size() to determine whether it needs to create a new thread.
// * this->ActiveThreads should be up to date as it is also used to
// determine whether to create a new thread.
void grow();
#endif
/// Threads in flight
std::vector<llvm::thread> Threads;
@ -137,7 +157,7 @@ private:
std::queue<std::function<void()>> Tasks;
/// Locking and signaling for accessing the Tasks queue.
std::mutex QueueLock;
mutable std::mutex QueueLock;
std::condition_variable QueueCondition;
/// Signaling for job completion
@ -151,7 +171,10 @@ private:
bool EnableFlag = true;
#endif
unsigned ThreadCount;
const ThreadPoolStrategy Strategy;
/// Maximum number of threads to potentially grow this pool to.
const unsigned MaxThreadCount;
};
}

View File

@ -20,50 +20,49 @@ using namespace llvm;
#if LLVM_ENABLE_THREADS
ThreadPool::ThreadPool(ThreadPoolStrategy S)
: ThreadCount(S.compute_thread_count()) {
// Create ThreadCount threads that will loop forever, wait on QueueCondition
// for tasks to be queued or the Pool to be destroyed.
Threads.reserve(ThreadCount);
for (unsigned ThreadID = 0; ThreadID < ThreadCount; ++ThreadID) {
Threads.emplace_back([S, ThreadID, this] {
S.apply_thread_strategy(ThreadID);
while (true) {
std::function<void()> Task;
{
std::unique_lock<std::mutex> LockGuard(QueueLock);
// Wait for tasks to be pushed in the queue
QueueCondition.wait(LockGuard,
[&] { return !EnableFlag || !Tasks.empty(); });
// Exit condition
if (!EnableFlag && Tasks.empty())
return;
// Yeah, we have a task, grab it and release the lock on the queue
void ThreadPool::grow() {
if (Threads.size() >= MaxThreadCount)
return; // Already hit the max thread pool size.
if (ActiveThreads + Tasks.size() <= Threads.size())
return; // We have enough threads for now.
int ThreadID = Threads.size();
Threads.emplace_back([this, ThreadID] {
Strategy.apply_thread_strategy(ThreadID);
while (true) {
std::function<void()> Task;
{
std::unique_lock<std::mutex> LockGuard(QueueLock);
// Wait for tasks to be pushed in the queue
QueueCondition.wait(LockGuard,
[&] { return !EnableFlag || !Tasks.empty(); });
// Exit condition
if (!EnableFlag && Tasks.empty())
return;
// Yeah, we have a task, grab it and release the lock on the queue
// We first need to signal that we are active before popping the queue
// in order for wait() to properly detect that even if the queue is
// empty, there is still a task in flight.
++ActiveThreads;
Task = std::move(Tasks.front());
Tasks.pop();
}
// Run the task we just grabbed
Task();
bool Notify;
{
// Adjust `ActiveThreads`, in case someone waits on ThreadPool::wait()
std::lock_guard<std::mutex> LockGuard(QueueLock);
--ActiveThreads;
Notify = workCompletedUnlocked();
}
// Notify task completion if this is the last active thread, in case
// someone waits on ThreadPool::wait().
if (Notify)
CompletionCondition.notify_all();
// We first need to signal that we are active before popping the queue
// in order for wait() to properly detect that even if the queue is
// empty, there is still a task in flight.
++ActiveThreads;
Task = std::move(Tasks.front());
Tasks.pop();
}
});
}
// Run the task we just grabbed
Task();
bool Notify;
{
// Adjust `ActiveThreads`, in case someone waits on ThreadPool::wait()
std::lock_guard<std::mutex> LockGuard(QueueLock);
--ActiveThreads;
Notify = workCompletedUnlocked();
}
// Notify task completion if this is the last active thread, in case
// someone waits on ThreadPool::wait().
if (Notify)
CompletionCondition.notify_all();
}
});
}
void ThreadPool::wait() {
@ -73,6 +72,7 @@ void ThreadPool::wait() {
}
bool ThreadPool::isWorkerThread() const {
std::unique_lock<std::mutex> LockGuard(QueueLock);
llvm::thread::id CurrentThreadId = llvm::this_thread::get_id();
for (const llvm::thread &Thread : Threads)
if (CurrentThreadId == Thread.get_id())