forked from mindspore-Ecosystem/mindspore
!21121 optimize thread pool
Merge pull request !21121 from kisnwang/optimize-thread-pool
Commit 8b6beddca3
@@ -24,9 +24,10 @@
 namespace mindspore {
 namespace common {
 #if ENABLE_D || ENABLE_GPU
-const size_t kDeviceNum = 8;
+constexpr size_t kDeviceNum = 8;
 #endif
-const size_t kMaxThreadNum = 23;
+constexpr size_t kMaxThreadNum = 23;
+constexpr size_t kYieldThreshold = 1000;
 
 ThreadPool::ThreadPool() {
   size_t process_core_num = std::thread::hardware_concurrency() - 1;
@@ -46,32 +47,47 @@ ThreadPool::ThreadPool() {
   }
 }
 
-void ThreadPool::SyncRunLoop() {
+void ThreadPool::SyncRunLoop(const std::shared_ptr<ThreadContext> &context) {
+  if (context == nullptr) {
+    return;
+  }
+  size_t yield_count = 0;
   while (true) {
-    Task task;
-    {
-      std::unique_lock<std::mutex> lock(task_mutex_);
-      task_cond_var_.wait(lock, [this] { return !task_queue_.empty() || exit_run_; });
-      if (exit_run_) {
-        return;
-      }
-      task = task_queue_.front();
-      task_queue_.pop();
+    if (exit_run_) {
+      return;
     }
+
+    if (!context->task) {
+      ++yield_count;
+      if (yield_count > kYieldThreshold) {
+        yield_count = 0;
+        std::unique_lock<std::mutex> lock(context->mutex);
+        context->cond_var.wait(lock, [&context, this] { return context->task != nullptr || exit_run_; });
+      } else {
+        std::this_thread::yield();
+        continue;
+      }
+    }
+
+    if (exit_run_) {
+      return;
+    }
+
     try {
+      auto &task = *(context->task);
       task();
     } catch (std::exception &e) {
       MsException::Instance().SetException();
     }
-    {
-      std::unique_lock<std::mutex> task_lock(task_mutex_);
-      task_finished_count_ = task_finished_count_ + 1;
-    }
-    finished_cond_var_.notify_one();
+    yield_count = 0;
+    context->task = nullptr;
   }
 }
 
 bool ThreadPool::SyncRun(const std::vector<Task> &tasks) {
   if (tasks.empty()) {
     return true;
   }
+  if (tasks.size() == 1) {
+    auto ret = tasks[0]();
+    return ret == SUCCESS;
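Note: the reworked SyncRunLoop above drops the shared, mutex-guarded task queue. Each worker now owns a ThreadContext, spins with std::this_thread::yield() for up to kYieldThreshold iterations while its task pointer is empty, falls back to blocking on the per-context condition variable, and signals completion by resetting the pointer to nullptr. Below is a minimal, self-contained sketch of this spin-then-block handoff for reference only; Slot, WorkerLoop, kYieldLimit and the main() driver are illustrative names rather than MindSpore code, and the task pointer is made std::atomic here to keep the sketch itself race-free.

#include <atomic>
#include <condition_variable>
#include <cstdio>
#include <functional>
#include <mutex>
#include <thread>
#include <vector>

// Per-worker slot: the dispatcher installs a task pointer, the worker clears it when done.
struct Slot {
  std::mutex mutex;
  std::condition_variable cond_var;
  std::atomic<const std::function<void()> *> task{nullptr};
};

constexpr size_t kYieldLimit = 1000;  // plays the role of kYieldThreshold

void WorkerLoop(Slot &slot, std::atomic_bool &exit_flag) {
  size_t yields = 0;
  while (!exit_flag) {
    const std::function<void()> *task = slot.task.load();
    if (task == nullptr) {
      if (++yields > kYieldLimit) {
        // Spun long enough without work: block on the slot's condition variable.
        yields = 0;
        std::unique_lock<std::mutex> lock(slot.mutex);
        slot.cond_var.wait(lock, [&] { return slot.task.load() != nullptr || exit_flag.load(); });
      } else {
        std::this_thread::yield();  // busy-wait phase
      }
      continue;
    }
    (*task)();                 // run the handed-over task
    slot.task.store(nullptr);  // completion signal: the dispatcher polls this field
    yields = 0;
  }
}

int main() {
  Slot slot;
  std::atomic_bool exit_flag{false};
  std::thread worker(WorkerLoop, std::ref(slot), std::ref(exit_flag));

  std::vector<std::function<void()>> tasks = {[] { std::puts("task 0"); }, [] { std::puts("task 1"); }};
  for (const auto &t : tasks) {
    while (slot.task.load() != nullptr) std::this_thread::yield();  // wait until the slot is free
    {
      std::lock_guard<std::mutex> lock(slot.mutex);
      slot.task.store(&t);
    }
    slot.cond_var.notify_one();
  }
  while (slot.task.load() != nullptr) std::this_thread::yield();  // wait for the last task to finish

  {
    std::lock_guard<std::mutex> lock(slot.mutex);  // hold the lock so the shutdown wakeup cannot be missed
    exit_flag = true;
  }
  slot.cond_var.notify_one();
  worker.join();
  return 0;
}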
@@ -85,20 +101,39 @@ bool ThreadPool::SyncRun(const std::vector<Task> &tasks) {
     if (task_num < max_thread_num_) {
       new_thread_num = task_num;
     }
+    contexts_.resize(new_thread_num);
     for (size_t i = thread_num; i < new_thread_num; ++i) {
-      sync_run_threads_.emplace_back(std::thread(&ThreadPool::SyncRunLoop, this));
+      contexts_[i] = std::make_shared<ThreadContext>();
+      sync_run_threads_.emplace_back(std::thread(&ThreadPool::SyncRunLoop, this, contexts_[i]));
     }
   }
 
-  for (auto &task : tasks) {
-    std::lock_guard<std::mutex> task_lock(task_mutex_);
-    task_queue_.push(task);
-    task_cond_var_.notify_one();
+  if (contexts_.empty()) {
+    return true;
   }
-  {
-    std::unique_lock<std::mutex> task_lock(task_mutex_);
-    finished_cond_var_.wait(task_lock, [this, task_num] { return task_num == task_finished_count_; });
-    task_finished_count_ = 0;
+  size_t used_thread_num = contexts_.size();
+  if (task_num < used_thread_num) {
+    used_thread_num = task_num;
   }
+  bool running = true;
+  size_t task_index = 0;
+  while (running) {
+    running = false;
+    for (size_t i = 0; i < used_thread_num; ++i) {
+      MS_EXCEPTION_IF_NULL(contexts_[i]);
+      auto &task_run = contexts_[i]->task;
+      if (task_run) {
+        running = true;
+      } else if (task_index < task_num) {
+        std::lock_guard<std::mutex> task_lock(contexts_[i]->mutex);
+        contexts_[i]->task = &(tasks[task_index]);
+        contexts_[i]->cond_var.notify_one();
+        running = true;
+        ++task_index;
+      }
+    }
+    if (running) {
+      std::this_thread::yield();
+    }
+  }
   return true;
 }
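For orientation, a hedged caller-side sketch of how SyncRun is typically driven. The private constructor in the header diff below implies a singleton accessor; GetInstance() is assumed here and does not appear in this diff, so treat the snippet as illustrative rather than something to compile against a specific header.

// Hypothetical usage of common::ThreadPool::SyncRun; GetInstance() is an assumption.
std::vector<mindspore::common::Task> tasks;
for (int chunk = 0; chunk < 4; ++chunk) {
  tasks.emplace_back([chunk]() -> int {
    // ... process one chunk of the workload ...
    return mindspore::common::SUCCESS;
  });
}
bool ok = mindspore::common::ThreadPool::GetInstance().SyncRun(tasks);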
@@ -114,7 +149,10 @@ void ThreadPool::ClearThreadPool() {
     return;
   }
   exit_run_ = true;
-  task_cond_var_.notify_all();
+  for (auto &context : contexts_) {
+    MS_EXCEPTION_IF_NULL(context);
+    context->cond_var.notify_one();
+  }
   for (auto &it : sync_run_threads_) {
     if (it.joinable()) {
       it.join();
@@ -35,6 +35,12 @@ namespace common {
 enum Status { FAIL = -1, SUCCESS = 0 };
 using Task = std::function<int()>;
 
+struct ThreadContext {
+  std::mutex mutex;
+  std::condition_variable cond_var;
+  const Task *task{nullptr};
+};
+
 class ThreadPool {
  public:
   ~ThreadPool();
@@ -47,17 +53,13 @@ class ThreadPool {
 
  private:
   ThreadPool();
-  void SyncRunLoop();
+  void SyncRunLoop(const std::shared_ptr<ThreadContext> &context);
 
   size_t max_thread_num_{1};
   std::mutex pool_mtx_;
   std::atomic_bool exit_run_ = {false};
-  std::queue<Task> task_queue_;
-  std::mutex task_mutex_;
-  std::condition_variable task_cond_var_;
-  size_t task_finished_count_{0};
-  std::condition_variable finished_cond_var_;
   std::vector<std::thread> sync_run_threads_{};
+  std::vector<std::shared_ptr<ThreadContext>> contexts_;
 };
 }  // namespace common
 }  // namespace mindspore
@@ -95,9 +95,7 @@ bool MemScheduler::PreCompute(void *stream) {
   for (auto &event : events) {
     MS_EXCEPTION_IF_NULL(event);
     MS_LOG(DEBUG) << "Pre compute " << compute_index_ << ": " << event->key << " v " << event->type;
-    if (event->type == kInit) {
-      auto host_ptr = init_host_ptr_[event->key];
-      MS_EXCEPTION_IF_NULL(host_ptr);
+    if (event->type == kInit || event->type == kMalloc) {
       auto priority = mem_priority_[event->key];
       auto iter = high_priority_device_ptr_.find(event->key);
       if (priority != kMemPriorityLow && iter != high_priority_device_ptr_.end()) {
@@ -112,12 +110,11 @@ bool MemScheduler::PreCompute(void *stream) {
       if (priority != kMemPriorityLow) {
         high_priority_device_ptr_[event->key] = device_ptr;
       }
-      mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
-      mem_result_[event->key] = device_ptr;
-    } else if (event->type == kMalloc) {
-      auto device_ptr = mem_handler_->MallocDevice(event->mem_size);
-      if (device_ptr == nullptr) {
-        return false;
-      }
+      if (event->type == kInit) {
+        auto host_ptr = init_host_ptr_[event->key];
+        MS_EXCEPTION_IF_NULL(host_ptr);
+        mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
+      }
       mem_result_[event->key] = device_ptr;
     } else if (event->type == kSwapIn) {
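Read together, the two MemScheduler hunks fold the old kInit branch into the kMalloc path: both event types now share the device-pointer acquisition, and only kInit additionally swaps the host data onto the device before the result is recorded in mem_result_. A small stand-alone sketch of that shared-allocation, conditional-initialization shape follows; Event, AcquireDevicePtr, PreComputeOne and the maps are simplified stand-ins rather than the real MemScheduler members, and the allocation details that fall between the two hunks are not shown in this diff.

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <unordered_map>

enum class EventType { kInit, kMalloc };

struct Event {
  EventType type;
  uint64_t key;
  size_t mem_size;
};

// Stand-in allocation step; in the real scheduler this is where a cached
// high-priority pointer would be reused or device memory would be allocated.
void *AcquireDevicePtr(size_t mem_size) { return std::malloc(mem_size); }

bool PreComputeOne(const Event &event, const std::unordered_map<uint64_t, void *> &init_host_ptr,
                   std::unordered_map<uint64_t, void *> *mem_result) {
  if (event.type == EventType::kInit || event.type == EventType::kMalloc) {
    void *device_ptr = AcquireDevicePtr(event.mem_size);
    if (device_ptr == nullptr) {
      return false;
    }
    if (event.type == EventType::kInit) {
      // Only kInit copies host data onto the freshly acquired device memory.
      std::memcpy(device_ptr, init_host_ptr.at(event.key), event.mem_size);  // stand-in for SwapIn
    }
    (*mem_result)[event.key] = device_ptr;
  }
  return true;
}

int main() {
  float host_data[4] = {1.f, 2.f, 3.f, 4.f};
  std::unordered_map<uint64_t, void *> init_host_ptr{{42, host_data}};
  std::unordered_map<uint64_t, void *> mem_result;

  PreComputeOne({EventType::kInit, 42, sizeof(host_data)}, init_host_ptr, &mem_result);
  PreComputeOne({EventType::kMalloc, 43, 256}, init_host_ptr, &mem_result);

  std::printf("allocated %zu buffers\n", mem_result.size());
  for (auto &kv : mem_result) std::free(kv.second);
  return 0;
}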