From 38d35ae96bbc1966693274a4b16b708379b07f10 Mon Sep 17 00:00:00 2001 From: kswang Date: Tue, 9 Nov 2021 19:14:40 +0800 Subject: [PATCH] optimize common thread pool --- mindspore/ccsrc/common/thread_pool.cc | 94 +++++++++++++------ mindspore/ccsrc/common/thread_pool.h | 14 +-- .../ccsrc/runtime/device/memory_scheduler.cc | 15 ++- 3 files changed, 80 insertions(+), 43 deletions(-) diff --git a/mindspore/ccsrc/common/thread_pool.cc b/mindspore/ccsrc/common/thread_pool.cc index 43e07e697ac..25a9543dc34 100644 --- a/mindspore/ccsrc/common/thread_pool.cc +++ b/mindspore/ccsrc/common/thread_pool.cc @@ -24,9 +24,10 @@ namespace mindspore { namespace common { #if ENABLE_D || ENABLE_GPU -const size_t kDeviceNum = 8; +constexpr size_t kDeviceNum = 8; #endif -const size_t kMaxThreadNum = 23; +constexpr size_t kMaxThreadNum = 23; +constexpr size_t kYieldThreshold = 1000; ThreadPool::ThreadPool() { size_t process_core_num = std::thread::hardware_concurrency() - 1; @@ -46,32 +47,47 @@ ThreadPool::ThreadPool() { } } -void ThreadPool::SyncRunLoop() { +void ThreadPool::SyncRunLoop(const std::shared_ptr<ThreadContext> &context) { + if (context == nullptr) { + return; + } + size_t yield_count = 0; while (true) { - Task task; - { - std::unique_lock lock(task_mutex_); - task_cond_var_.wait(lock, [this] { return !task_queue_.empty() || exit_run_; }); - if (exit_run_) { - return; - } - task = task_queue_.front(); - task_queue_.pop(); + if (exit_run_) { + return; } + + if (!context->task) { + ++yield_count; + if (yield_count > kYieldThreshold) { + yield_count = 0; + std::unique_lock lock(context->mutex); + context->cond_var.wait(lock, [&context, this] { return context->task != nullptr || exit_run_; }); + } else { + std::this_thread::yield(); + continue; + } + } + + if (exit_run_) { + return; + } + try { + auto &task = *(context->task); task(); } catch (std::exception &e) { MsException::Instance().SetException(); } - { - std::unique_lock task_lock(task_mutex_); - task_finished_count_ = 
task_finished_count_ + 1; - } - finished_cond_var_.notify_one(); + yield_count = 0; + context->task = nullptr; } } bool ThreadPool::SyncRun(const std::vector<Task> &tasks) { + if (tasks.empty()) { + return true; + } if (tasks.size() == 1) { auto ret = tasks[0](); return ret == SUCCESS; @@ -85,20 +101,39 @@ bool ThreadPool::SyncRun(const std::vector<Task> &tasks) { if (task_num < max_thread_num_) { new_thread_num = task_num; } + contexts_.resize(new_thread_num); for (size_t i = thread_num; i < new_thread_num; ++i) { - sync_run_threads_.emplace_back(std::thread(&ThreadPool::SyncRunLoop, this)); + contexts_[i] = std::make_shared<ThreadContext>(); + sync_run_threads_.emplace_back(std::thread(&ThreadPool::SyncRunLoop, this, contexts_[i])); } } - - for (auto &task : tasks) { - std::lock_guard task_lock(task_mutex_); - task_queue_.push(task); - task_cond_var_.notify_one(); + if (contexts_.empty()) { + return true; } - { - std::unique_lock task_lock(task_mutex_); - finished_cond_var_.wait(task_lock, [this, task_num] { return task_num == task_finished_count_; }); - task_finished_count_ = 0; + size_t used_thread_num = contexts_.size(); + if (task_num < used_thread_num) { + used_thread_num = task_num; + } + bool running = true; + size_t task_index = 0; + while (running) { + running = false; + for (size_t i = 0; i < used_thread_num; ++i) { + MS_EXCEPTION_IF_NULL(contexts_[i]); + auto &task_run = contexts_[i]->task; + if (task_run) { + running = true; + } else if (task_index < task_num) { + std::lock_guard task_lock(contexts_[i]->mutex); + contexts_[i]->task = &(tasks[task_index]); + contexts_[i]->cond_var.notify_one(); + running = true; + ++task_index; + } + } + if (running) { + std::this_thread::yield(); + } } return true; } @@ -114,7 +149,10 @@ void ThreadPool::ClearThreadPool() { return; } exit_run_ = true; - task_cond_var_.notify_all(); + for (auto &context : contexts_) { + MS_EXCEPTION_IF_NULL(context); + context->cond_var.notify_one(); + } for (auto &it : sync_run_threads_) { if (it.joinable()) { 
it.join(); diff --git a/mindspore/ccsrc/common/thread_pool.h b/mindspore/ccsrc/common/thread_pool.h index 6a6aa71791c..82ddeeffedf 100644 --- a/mindspore/ccsrc/common/thread_pool.h +++ b/mindspore/ccsrc/common/thread_pool.h @@ -35,6 +35,12 @@ namespace common { enum Status { FAIL = -1, SUCCESS = 0 }; using Task = std::function<int()>; +struct ThreadContext { + std::mutex mutex; + std::condition_variable cond_var; + const Task *task{nullptr}; +}; + class ThreadPool { public: ~ThreadPool(); @@ -47,17 +53,13 @@ class ThreadPool { private: ThreadPool(); - void SyncRunLoop(); + void SyncRunLoop(const std::shared_ptr<ThreadContext> &context); size_t max_thread_num_{1}; std::mutex pool_mtx_; std::atomic_bool exit_run_ = {false}; - std::queue<Task> task_queue_; - std::mutex task_mutex_; - std::condition_variable task_cond_var_; - size_t task_finished_count_{0}; - std::condition_variable finished_cond_var_; std::vector<std::thread> sync_run_threads_{}; + std::vector<std::shared_ptr<ThreadContext>> contexts_; }; } // namespace common } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/memory_scheduler.cc b/mindspore/ccsrc/runtime/device/memory_scheduler.cc index b26db94b93a..e9fd7f54ffe 100644 --- a/mindspore/ccsrc/runtime/device/memory_scheduler.cc +++ b/mindspore/ccsrc/runtime/device/memory_scheduler.cc @@ -95,9 +95,7 @@ bool MemScheduler::PreCompute(void *stream) { for (auto &event : events) { MS_EXCEPTION_IF_NULL(event); MS_LOG(DEBUG) << "Pre compute " << compute_index_ << ": " << event->key << " v " << event->type; - if (event->type == kInit) { - auto host_ptr = init_host_ptr_[event->key]; - MS_EXCEPTION_IF_NULL(host_ptr); + if (event->type == kInit || event->type == kMalloc) { auto priority = mem_priority_[event->key]; auto iter = high_priority_device_ptr_.find(event->key); if (priority != kMemPriorityLow && iter != high_priority_device_ptr_.end()) { @@ -112,12 +110,11 @@ bool MemScheduler::PreCompute(void *stream) { if (priority != kMemPriorityLow) { high_priority_device_ptr_[event->key] = device_ptr; } - 
mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream); - mem_result_[event->key] = device_ptr; - } else if (event->type == kMalloc) { - auto device_ptr = mem_handler_->MallocDevice(event->mem_size); - if (device_ptr == nullptr) { - return false; + + if (event->type == kInit) { + auto host_ptr = init_host_ptr_[event->key]; + MS_EXCEPTION_IF_NULL(host_ptr); + mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream); } mem_result_[event->key] = device_ptr; } else if (event->type == kSwapIn) {