Accelerate worker thread task dispatch and MKL thread-count search

This commit is contained in:
fangzehua 2022-05-10 11:01:18 +08:00
parent b1023addba
commit f8268506da
5 changed files with 70 additions and 36 deletions

View File

@ -183,34 +183,44 @@ void MKLCpuKernelMod::ExecutePrimitive() {
MS_EXCEPTION_IF_NULL(primitive_); MS_EXCEPTION_IF_NULL(primitive_);
#ifdef USE_MS_THREADPOOL_FOR_DNNL #ifdef USE_MS_THREADPOOL_FOR_DNNL
// auto-search the fastest thread number // auto-search the fastest thread number
const size_t MAX_POW = 6; const std::vector<size_t> kSearchThreadList{4, 8, 16, 24, 32};
const size_t AVG_COUNT = 5; const size_t kAvgCount = 5;
const size_t DIFF = 2; const size_t kDiff = 2;
size_t current_pow = parallel_search_info_.search_count / AVG_COUNT; size_t current_pow = parallel_search_info_.search_count / kAvgCount;
int current_thread_nums = static_cast<int>(std::pow(2.0f, current_pow));
auto mkl_pool = dynamic_cast<mkl_threadpool *>(mkl_threadpool_.get()); auto mkl_pool = dynamic_cast<mkl_threadpool *>(mkl_threadpool_.get());
if (current_pow < MAX_POW) { if (current_pow < kSearchThreadList.size()) {
if (parallel_search_info_.search_count % AVG_COUNT == 0) { if (parallel_search_info_.search_count % kAvgCount == 0) {
parallel_search_info_.tmp_sum_cost_time = 0; parallel_search_info_.tmp_sum_cost_time = 0;
} }
double start_time = GetTime(); double start_time = GetTime();
int current_thread_nums = kSearchThreadList[current_pow];
mkl_pool->set_num_threads(current_thread_nums); mkl_pool->set_num_threads(current_thread_nums);
MS_LOG(DEBUG) << "begin to invoke primitive::execute"; MS_LOG(DEBUG) << "begin to invoke primitive::execute";
primitive_->execute(stream_, arguments_); primitive_->execute(stream_, arguments_);
MS_LOG(DEBUG) << "end to invoke primitive::execute"; MS_LOG(DEBUG) << "end to invoke primitive::execute";
double cost_time = GetTime() - start_time; double cost_time = GetTime() - start_time;
parallel_search_info_.tmp_sum_cost_time += cost_time; // skip the first step to warm up.
if (parallel_search_info_.search_count != 0) {
parallel_search_info_.tmp_sum_cost_time += cost_time;
}
parallel_search_info_.search_count++; parallel_search_info_.search_count++;
if (parallel_search_info_.search_count % AVG_COUNT == 0) { if (parallel_search_info_.search_count % kAvgCount == 0) {
if (parallel_search_info_.min_cost_time > parallel_search_info_.tmp_sum_cost_time) { double avg_time = 0;
parallel_search_info_.min_cost_time = parallel_search_info_.tmp_sum_cost_time; // first avg will skip the first step
if (parallel_search_info_.search_count / kAvgCount == 0) {
avg_time = parallel_search_info_.tmp_sum_cost_time / (kAvgCount - 1);
} else {
avg_time = parallel_search_info_.tmp_sum_cost_time / kAvgCount;
}
if (parallel_search_info_.min_cost_time > avg_time) {
parallel_search_info_.min_cost_time = avg_time;
parallel_search_info_.best_pow = current_pow; parallel_search_info_.best_pow = current_pow;
} else if (current_pow - parallel_search_info_.best_pow >= DIFF) { } else if (current_pow - parallel_search_info_.best_pow >= kDiff) {
parallel_search_info_.search_count = AVG_COUNT * MAX_POW; parallel_search_info_.search_count = kAvgCount * kSearchThreadList.size();
} }
} }
} else { } else {
int best_thread_nums = static_cast<int>(std::pow(2.0f, parallel_search_info_.best_pow)); int best_thread_nums = kSearchThreadList[parallel_search_info_.best_pow];
mkl_pool->set_num_threads(best_thread_nums); mkl_pool->set_num_threads(best_thread_nums);
MS_LOG(DEBUG) << "begin to invoke primitive::execute"; MS_LOG(DEBUG) << "begin to invoke primitive::execute";
primitive_->execute(stream_, arguments_); primitive_->execute(stream_, arguments_);

View File

@ -134,14 +134,19 @@ class mkl_threadpool : public dnnl::threadpool_interop::threadpool_iface {
private: private:
ActorThreadPool *tp_; ActorThreadPool *tp_;
int thread_num_{8}; int thread_num_{8};
bool first_parallel{true};
public: public:
explicit mkl_threadpool(ActorThreadPool *tp) { tp_ = tp; } explicit mkl_threadpool(ActorThreadPool *tp) { tp_ = tp; }
void set_num_threads(int num) { thread_num_ = num; } void set_num_threads(int num) { thread_num_ = num; }
int get_num_threads() const override { return std::min(SizeToInt(tp_->GetKernelThreadNum()), thread_num_); } int get_num_threads() const override { return std::min(SizeToInt(tp_->GetKernelThreadNum()), thread_num_); }
bool get_in_parallel() const override { return false; } bool get_in_parallel() const override { return !first_parallel; }
uint64_t get_flags() const override { return 0; } uint64_t get_flags() const override { return 0; }
void parallel_for(int n, const std::function<void(int, int)> &fn) override { void parallel_for(int n, const std::function<void(int, int)> &fn) override {
bool need_change_flag = first_parallel ? true : false;
if (need_change_flag) {
first_parallel = false;
}
int nthr = get_num_threads(); int nthr = get_num_threads();
int n_jobs = std::min(n, nthr); int n_jobs = std::min(n, nthr);
auto func = [&, n_jobs](void *, int i, float, float) { auto func = [&, n_jobs](void *, int i, float, float) {
@ -149,6 +154,9 @@ class mkl_threadpool : public dnnl::threadpool_interop::threadpool_iface {
return 0; return 0;
}; };
(void)tp_->ParallelLaunch(func, nullptr, n_jobs); (void)tp_->ParallelLaunch(func, nullptr, n_jobs);
if (need_change_flag) {
first_parallel = true;
}
} }
}; };
#endif #endif

View File

@ -106,10 +106,13 @@ bool Worker::RunLocalKernelTask() {
if (task == nullptr) { if (task == nullptr) {
return false; return false;
} }
int task_id = task_id_.load(std::memory_order_consume); int task_id_start = task_id_start_.load(std::memory_order_consume);
task->status |= task->func(task->content, task_id, lhs_scale_, rhs_scale_); int task_id_end = task_id_end_.load(std::memory_order_consume);
for (int i = task_id_start; i < task_id_end; ++i) {
task->status |= task->func(task->content, i, lhs_scale_, rhs_scale_);
}
task_.store(nullptr, std::memory_order_relaxed); task_.store(nullptr, std::memory_order_relaxed);
(void)++task->finished; task->finished += task_id_end - task_id_start;
return true; return true;
} }
@ -134,11 +137,12 @@ void Worker::set_scale(float lhs_scale, float rhs_scale) {
rhs_scale_ = rhs_scale; rhs_scale_ = rhs_scale;
} }
void Worker::Active(Task *task, int task_id) { void Worker::Active(Task *task, int task_id_start, int task_id_end) {
{ {
std::lock_guard<std::mutex> _l(mutex_); std::lock_guard<std::mutex> _l(mutex_);
THREAD_TEST_TRUE(task_ == nullptr); THREAD_TEST_TRUE(task_ == nullptr);
task_id_.store(task_id, std::memory_order_relaxed); task_id_start_.store(task_id_start, std::memory_order_relaxed);
task_id_end_.store(task_id_end, std::memory_order_relaxed);
task_.store(task, std::memory_order_release); task_.store(task, std::memory_order_release);
status_ = kThreadBusy; status_ = kThreadBusy;
} }
@ -236,6 +240,7 @@ int ThreadPool::SyncRunFunc(const Func &func, Content content, int start, int en
void ThreadPool::DistributeTask(Task *task, int task_num, Worker *curr) const { void ThreadPool::DistributeTask(Task *task, int task_num, Worker *curr) const {
int sum_frequency = 0; int sum_frequency = 0;
std::vector<Worker *> assigned; std::vector<Worker *> assigned;
assigned.reserve(task_num);
int num = static_cast<int>(workers_.size()) - 1; int num = static_cast<int>(workers_.size()) - 1;
int offset = 0; int offset = 0;
bool use_curr = (curr != nullptr) ? curr->get_task_free() : false; bool use_curr = (curr != nullptr) ? curr->get_task_free() : false;
@ -256,13 +261,9 @@ void ThreadPool::DistributeTask(Task *task, int task_num, Worker *curr) const {
} }
} }
// when there are not enough free threads,
// distribute other tasks to the master thread
if (use_curr) { if (use_curr) {
for (; count < task_num; ++count) { assigned.push_back(curr);
assigned.push_back(curr); sum_frequency += curr->frequency();
sum_frequency += curr->frequency();
}
} else if (assigned.size() != static_cast<size_t>(task_num)) { } else if (assigned.size() != static_cast<size_t>(task_num)) {
CalculateScales(assigned, sum_frequency); CalculateScales(assigned, sum_frequency);
ActiveWorkers(assigned, task, assigned.size(), curr); ActiveWorkers(assigned, task, assigned.size(), curr);
@ -271,7 +272,7 @@ void ThreadPool::DistributeTask(Task *task, int task_num, Worker *curr) const {
} }
CalculateScales(assigned, sum_frequency); CalculateScales(assigned, sum_frequency);
ActiveWorkers(assigned, task, assigned.size(), curr); ActiveWorkers(assigned, task, task_num, curr);
} }
void ThreadPool::CalculateScales(const std::vector<Worker *> &assigned, int sum_frequency) const { void ThreadPool::CalculateScales(const std::vector<Worker *> &assigned, int sum_frequency) const {
@ -292,12 +293,26 @@ void ThreadPool::CalculateScales(const std::vector<Worker *> &assigned, int sum_
void ThreadPool::ActiveWorkers(const std::vector<Worker *> &workers, Task *task, int task_num, void ThreadPool::ActiveWorkers(const std::vector<Worker *> &workers, Task *task, int task_num,
const Worker *curr) const { const Worker *curr) const {
for (int i = 0; i < task_num; ++i) { // recalculate task num for each worker.
Worker *worker = workers[i]; int worker_num = static_cast<int>(workers.size());
THREAD_RETURN_IF_NULL(worker); if (worker_num > 0) {
worker->Active(task, i); int each_worker_task_num = task_num / worker_num;
if (worker == curr) { int rest_task_num = task_num % worker_num;
(void)worker->RunLocalKernelTask(); int start = 0;
int end;
for (int i = 0; i < worker_num; ++i) {
Worker *worker = workers[i];
THREAD_RETURN_IF_NULL(worker);
if (i < rest_task_num) {
end = start + each_worker_task_num + 1;
} else {
end = start + each_worker_task_num;
}
worker->Active(task, start, end);
if (worker == curr) {
(void)worker->RunLocalKernelTask();
}
start = end;
} }
} }
} }

View File

@ -75,7 +75,7 @@ class Worker {
// create thread and start running at the same time // create thread and start running at the same time
virtual void CreateThread(); virtual void CreateThread();
// assign task and then activate thread // assign task and then activate thread
void Active(Task *task, int task_id); void Active(Task *task, int task_id_start, int task_id_end);
// activate thread // activate thread
virtual void Active(); virtual void Active();
// whether or not it is idle and marked as held // whether or not it is idle and marked as held
@ -127,7 +127,8 @@ class Worker {
std::condition_variable cond_var_; std::condition_variable cond_var_;
std::atomic<Task *> task_{nullptr}; std::atomic<Task *> task_{nullptr};
std::atomic_int task_id_{0}; std::atomic_int task_id_start_{0};
std::atomic_int task_id_end_{0};
float lhs_scale_{0.}; float lhs_scale_{0.};
float rhs_scale_{kMaxScale}; float rhs_scale_{kMaxScale};
int frequency_{kDefaultFrequency}; int frequency_{kDefaultFrequency};

View File

@ -6,7 +6,7 @@ index 1397073ba..041a3436f 100644
// function supports process affinity. // function supports process affinity.
unsigned get_max_threads_to_use() { unsigned get_max_threads_to_use() {
int num_cores_per_socket = (int)dnnl::impl::cpu::platform::get_num_cores(); int num_cores_per_socket = (int)dnnl::impl::cpu::platform::get_num_cores();
+ if (num_cores_per_socket == 0) + if (num_cores_per_socket <= 1)
+ num_cores_per_socket = std::thread::hardware_concurrency(); + num_cores_per_socket = std::thread::hardware_concurrency();
#if defined(_WIN32) #if defined(_WIN32)
DWORD_PTR proc_affinity_mask; DWORD_PTR proc_affinity_mask;