!21121 optimize thread pool

Merge pull request !21121 from kisnwang/optimize-thread-pool
i-robot 2021-11-10 06:21:18 +00:00 committed by Gitee
commit 8b6beddca3
3 changed files with 80 additions and 43 deletions
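Summary of the change, as read from the diff below: the common ThreadPool drops its shared task_queue_, task_cond_var_ and finished_cond_var_ in favour of one ThreadContext per worker thread, each with its own mutex, condition variable and task pointer. SyncRun hands every task to an idle context and spins until all slots are cleared, while each worker busy-polls its own slot and only falls back to a blocking wait after kYieldThreshold (1000) fruitless yields. A separate hunk also lets MemScheduler::PreCompute handle kInit and kMalloc events through one allocation path. The sketch below shows the worker side of this pattern in isolation; Slot, WorkerLoop and the std::atomic members are simplifications for illustration, not the MindSpore API (the real ThreadContext keeps a plain const Task * protected by its mutex and condition variable).

// Illustrative sketch of the new worker loop (simplified names, not the MindSpore code):
// each worker spins on its own task slot and yields up to a threshold before it
// falls back to a blocking wait on the slot's condition variable.
#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <functional>
#include <mutex>
#include <thread>

using Task = std::function<int()>;

struct Slot {  // stands in for ThreadContext; atomics keep the sketch data-race free
  std::mutex mutex;
  std::condition_variable cond_var;
  std::atomic<const Task *> task{nullptr};  // set by the submitter, cleared by the worker
};

constexpr size_t kYieldThreshold = 1000;

void WorkerLoop(Slot *slot, const std::atomic<bool> *exit_run) {
  size_t yield_count = 0;
  while (true) {
    if (exit_run->load()) {
      return;
    }
    if (slot->task.load() == nullptr) {
      if (++yield_count <= kYieldThreshold) {
        std::this_thread::yield();  // cheap busy-poll while tasks arrive back to back
        continue;
      }
      yield_count = 0;  // idle for too long: block until the submitter notifies this slot
      std::unique_lock<std::mutex> lock(slot->mutex);
      slot->cond_var.wait(lock, [&] { return slot->task.load() != nullptr || exit_run->load(); });
    }
    if (exit_run->load()) {
      return;
    }
    const Task *task = slot->task.load();
    (void)(*task)();            // run the task; the sketch ignores the returned status
    yield_count = 0;
    slot->task.store(nullptr);  // an empty slot tells the submitter the task finished
  }
}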

@@ -24,9 +24,10 @@
namespace mindspore {
namespace common {
#if ENABLE_D || ENABLE_GPU
const size_t kDeviceNum = 8;
constexpr size_t kDeviceNum = 8;
#endif
const size_t kMaxThreadNum = 23;
constexpr size_t kMaxThreadNum = 23;
constexpr size_t kYieldThreshold = 1000;
ThreadPool::ThreadPool() {
size_t process_core_num = std::thread::hardware_concurrency() - 1;
@@ -46,32 +47,47 @@ ThreadPool::ThreadPool() {
}
}
void ThreadPool::SyncRunLoop() {
void ThreadPool::SyncRunLoop(const std::shared_ptr<ThreadContext> &context) {
if (context == nullptr) {
return;
}
size_t yield_count = 0;
while (true) {
Task task;
{
std::unique_lock<std::mutex> lock(task_mutex_);
task_cond_var_.wait(lock, [this] { return !task_queue_.empty() || exit_run_; });
if (exit_run_) {
return;
}
task = task_queue_.front();
task_queue_.pop();
if (exit_run_) {
return;
}
if (!context->task) {
++yield_count;
if (yield_count > kYieldThreshold) {
yield_count = 0;
std::unique_lock<std::mutex> lock(context->mutex);
context->cond_var.wait(lock, [&context, this] { return context->task != nullptr || exit_run_; });
} else {
std::this_thread::yield();
continue;
}
}
if (exit_run_) {
return;
}
try {
auto &task = *(context->task);
task();
} catch (std::exception &e) {
MsException::Instance().SetException();
}
{
std::unique_lock<std::mutex> task_lock(task_mutex_);
task_finished_count_ = task_finished_count_ + 1;
}
finished_cond_var_.notify_one();
yield_count = 0;
context->task = nullptr;
}
}
bool ThreadPool::SyncRun(const std::vector<Task> &tasks) {
if (tasks.empty()) {
return true;
}
if (tasks.size() == 1) {
auto ret = tasks[0]();
return ret == SUCCESS;
@@ -85,20 +101,39 @@ bool ThreadPool::SyncRun(const std::vector<Task> &tasks) {
if (task_num < max_thread_num_) {
new_thread_num = task_num;
}
contexts_.resize(new_thread_num);
for (size_t i = thread_num; i < new_thread_num; ++i) {
sync_run_threads_.emplace_back(std::thread(&ThreadPool::SyncRunLoop, this));
contexts_[i] = std::make_shared<ThreadContext>();
sync_run_threads_.emplace_back(std::thread(&ThreadPool::SyncRunLoop, this, contexts_[i]));
}
}
for (auto &task : tasks) {
std::lock_guard<std::mutex> task_lock(task_mutex_);
task_queue_.push(task);
task_cond_var_.notify_one();
if (contexts_.empty()) {
return true;
}
{
std::unique_lock<std::mutex> task_lock(task_mutex_);
finished_cond_var_.wait(task_lock, [this, task_num] { return task_num == task_finished_count_; });
task_finished_count_ = 0;
size_t used_thread_num = contexts_.size();
if (task_num < used_thread_num) {
used_thread_num = task_num;
}
bool running = true;
size_t task_index = 0;
while (running) {
running = false;
for (size_t i = 0; i < used_thread_num; ++i) {
MS_EXCEPTION_IF_NULL(contexts_[i]);
auto &task_run = contexts_[i]->task;
if (task_run) {
running = true;
} else if (task_index < task_num) {
std::lock_guard<std::mutex> task_lock(contexts_[i]->mutex);
contexts_[i]->task = &(tasks[task_index]);
contexts_[i]->cond_var.notify_one();
running = true;
++task_index;
}
}
if (running) {
std::this_thread::yield();
}
}
return true;
}
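For the other half of the pattern, here is an illustrative submitter plus a small usage example mirroring SyncRun's assign-and-spin loop and the shutdown notification added to ClearThreadPool below. It reuses Task, Slot, WorkerLoop and the headers from the sketch near the top of this page; the names are again assumptions, not the MindSpore API.

// Illustrative submitter side (reuses Task, Slot, WorkerLoop and headers from the sketch above).
#include <memory>
#include <vector>

bool SyncRunSketch(const std::vector<Task> &tasks, const std::vector<std::shared_ptr<Slot>> &slots) {
  size_t task_index = 0;
  bool running = true;
  while (running) {
    running = false;
    for (const auto &slot : slots) {
      if (slot->task.load() != nullptr) {
        running = true;                       // this worker is still busy
      } else if (task_index < tasks.size()) {
        {
          std::lock_guard<std::mutex> lock(slot->mutex);
          slot->task.store(&tasks[task_index]);
        }
        slot->cond_var.notify_one();          // wake the worker if it is blocked
        running = true;
        ++task_index;
      }
    }
    if (running) {
      std::this_thread::yield();              // spin until every slot has drained
    }
  }
  return true;
}

int main() {
  std::atomic<bool> exit_run{false};
  std::vector<std::shared_ptr<Slot>> slots{std::make_shared<Slot>(), std::make_shared<Slot>()};
  std::vector<std::thread> workers;
  for (const auto &slot : slots) {
    workers.emplace_back(WorkerLoop, slot.get(), &exit_run);
  }
  std::vector<Task> tasks{[] { return 0; }, [] { return 0; }, [] { return 0; }};
  SyncRunSketch(tasks, slots);
  exit_run = true;                            // shutdown, mirroring ClearThreadPool
  for (const auto &slot : slots) {
    std::lock_guard<std::mutex> lock(slot->mutex);  // hold the lock so the wakeup is not lost
    slot->cond_var.notify_one();
  }
  for (auto &worker : workers) {
    worker.join();
  }
  return 0;
}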
@@ -114,7 +149,10 @@ void ThreadPool::ClearThreadPool() {
return;
}
exit_run_ = true;
task_cond_var_.notify_all();
for (auto &context : contexts_) {
MS_EXCEPTION_IF_NULL(context);
context->cond_var.notify_one();
}
for (auto &it : sync_run_threads_) {
if (it.joinable()) {
it.join();

@@ -35,6 +35,12 @@ namespace common {
enum Status { FAIL = -1, SUCCESS = 0 };
using Task = std::function<int()>;
struct ThreadContext {
std::mutex mutex;
std::condition_variable cond_var;
const Task *task{nullptr};
};
class ThreadPool {
public:
~ThreadPool();
@@ -47,17 +53,13 @@ class ThreadPool {
private:
ThreadPool();
void SyncRunLoop();
void SyncRunLoop(const std::shared_ptr<ThreadContext> &context);
size_t max_thread_num_{1};
std::mutex pool_mtx_;
std::atomic_bool exit_run_ = {false};
std::queue<Task> task_queue_;
std::mutex task_mutex_;
std::condition_variable task_cond_var_;
size_t task_finished_count_{0};
std::condition_variable finished_cond_var_;
std::vector<std::thread> sync_run_threads_{};
std::vector<std::shared_ptr<ThreadContext>> contexts_;
};
} // namespace common
} // namespace mindspore

@@ -95,9 +95,7 @@ bool MemScheduler::PreCompute(void *stream) {
for (auto &event : events) {
MS_EXCEPTION_IF_NULL(event);
MS_LOG(DEBUG) << "Pre compute " << compute_index_ << ": " << event->key << " v " << event->type;
if (event->type == kInit) {
auto host_ptr = init_host_ptr_[event->key];
MS_EXCEPTION_IF_NULL(host_ptr);
if (event->type == kInit || event->type == kMalloc) {
auto priority = mem_priority_[event->key];
auto iter = high_priority_device_ptr_.find(event->key);
if (priority != kMemPriorityLow && iter != high_priority_device_ptr_.end()) {
@@ -112,12 +110,11 @@ bool MemScheduler::PreCompute(void *stream) {
if (priority != kMemPriorityLow) {
high_priority_device_ptr_[event->key] = device_ptr;
}
mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
mem_result_[event->key] = device_ptr;
} else if (event->type == kMalloc) {
auto device_ptr = mem_handler_->MallocDevice(event->mem_size);
if (device_ptr == nullptr) {
return false;
if (event->type == kInit) {
auto host_ptr = init_host_ptr_[event->key];
MS_EXCEPTION_IF_NULL(host_ptr);
mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
}
mem_result_[event->key] = device_ptr;
} else if (event->type == kSwapIn) {
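The last hunk folds the old kMalloc branch of MemScheduler::PreCompute into the kInit path: both event types now share the device-allocation (and high-priority pointer reuse) logic, and only kInit still performs the host-to-device SwapIn. Below is a standalone toy that shows only that control-flow shape; the function name and signature are made up, and std::malloc/std::memcpy merely stand in for MallocDevice/SwapIn.

// Toy sketch of the merged kInit/kMalloc control flow (hypothetical names;
// std::malloc and std::memcpy stand in for MallocDevice and SwapIn).
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cstring>

enum EventType { kInit, kMalloc };

bool PreComputeEvent(EventType type, const void *host_ptr, size_t size, void **device_out) {
  void *device_ptr = std::malloc(size);  // shared allocation path for both event types
  if (device_ptr == nullptr) {
    return false;
  }
  if (type == kInit) {                   // only init events copy host data to the device
    std::memcpy(device_ptr, host_ptr, size);
  }
  *device_out = device_ptr;              // record the result for later compute/free events
  return true;
}

int main() {
  const char weights[] = "initial weights";
  void *init_ptr = nullptr;
  void *scratch_ptr = nullptr;
  bool ok = PreComputeEvent(kInit, weights, sizeof(weights), &init_ptr) &&
            PreComputeEvent(kMalloc, nullptr, 64, &scratch_ptr);
  std::printf("ok=%d, init buffer holds: %s\n", ok, static_cast<const char *>(init_ptr));
  std::free(init_ptr);
  std::free(scratch_ptr);
  return 0;
}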