Accelerate worker thread task distribution and MKL thread-count auto-search

This commit is contained in:
fangzehua 2022-05-10 11:01:18 +08:00
parent b1023addba
commit f8268506da
5 changed files with 70 additions and 36 deletions

View File

@ -183,34 +183,44 @@ void MKLCpuKernelMod::ExecutePrimitive() {
MS_EXCEPTION_IF_NULL(primitive_);
#ifdef USE_MS_THREADPOOL_FOR_DNNL
// add auto search
const size_t MAX_POW = 6;
const size_t AVG_COUNT = 5;
const size_t DIFF = 2;
size_t current_pow = parallel_search_info_.search_count / AVG_COUNT;
int current_thread_nums = static_cast<int>(std::pow(2.0f, current_pow));
const std::vector<size_t> kSearchThreadList{4, 8, 16, 24, 32};
const size_t kAvgCount = 5;
const size_t kDiff = 2;
size_t current_pow = parallel_search_info_.search_count / kAvgCount;
auto mkl_pool = dynamic_cast<mkl_threadpool *>(mkl_threadpool_.get());
if (current_pow < MAX_POW) {
if (parallel_search_info_.search_count % AVG_COUNT == 0) {
if (current_pow < kSearchThreadList.size()) {
if (parallel_search_info_.search_count % kAvgCount == 0) {
parallel_search_info_.tmp_sum_cost_time = 0;
}
double start_time = GetTime();
int current_thread_nums = kSearchThreadList[current_pow];
mkl_pool->set_num_threads(current_thread_nums);
MS_LOG(DEBUG) << "begin to invoke primitive::execute";
primitive_->execute(stream_, arguments_);
MS_LOG(DEBUG) << "end to invoke primitive::execute";
double cost_time = GetTime() - start_time;
parallel_search_info_.tmp_sum_cost_time += cost_time;
// skip the first step to warm up.
if (parallel_search_info_.search_count != 0) {
parallel_search_info_.tmp_sum_cost_time += cost_time;
}
parallel_search_info_.search_count++;
if (parallel_search_info_.search_count % AVG_COUNT == 0) {
if (parallel_search_info_.min_cost_time > parallel_search_info_.tmp_sum_cost_time) {
parallel_search_info_.min_cost_time = parallel_search_info_.tmp_sum_cost_time;
if (parallel_search_info_.search_count % kAvgCount == 0) {
double avg_time = 0;
// first avg will skip the first step
if (parallel_search_info_.search_count / kAvgCount == 0) {
avg_time = parallel_search_info_.tmp_sum_cost_time / (kAvgCount - 1);
} else {
avg_time = parallel_search_info_.tmp_sum_cost_time / kAvgCount;
}
if (parallel_search_info_.min_cost_time > avg_time) {
parallel_search_info_.min_cost_time = avg_time;
parallel_search_info_.best_pow = current_pow;
} else if (current_pow - parallel_search_info_.best_pow >= DIFF) {
parallel_search_info_.search_count = AVG_COUNT * MAX_POW;
} else if (current_pow - parallel_search_info_.best_pow >= kDiff) {
parallel_search_info_.search_count = kAvgCount * kSearchThreadList.size();
}
}
} else {
int best_thread_nums = static_cast<int>(std::pow(2.0f, parallel_search_info_.best_pow));
int best_thread_nums = kSearchThreadList[parallel_search_info_.best_pow];
mkl_pool->set_num_threads(best_thread_nums);
MS_LOG(DEBUG) << "begin to invoke primitive::execute";
primitive_->execute(stream_, arguments_);

View File

@ -134,14 +134,19 @@ class mkl_threadpool : public dnnl::threadpool_interop::threadpool_iface {
private:
ActorThreadPool *tp_;
int thread_num_{8};
bool first_parallel{true};
public:
explicit mkl_threadpool(ActorThreadPool *tp) { tp_ = tp; }
void set_num_threads(int num) { thread_num_ = num; }
int get_num_threads() const override { return std::min(SizeToInt(tp_->GetKernelThreadNum()), thread_num_); }
bool get_in_parallel() const override { return false; }
bool get_in_parallel() const override { return !first_parallel; }
uint64_t get_flags() const override { return 0; }
void parallel_for(int n, const std::function<void(int, int)> &fn) override {
bool need_change_flag = first_parallel ? true : false;
if (need_change_flag) {
first_parallel = false;
}
int nthr = get_num_threads();
int n_jobs = std::min(n, nthr);
auto func = [&, n_jobs](void *, int i, float, float) {
@ -149,6 +154,9 @@ class mkl_threadpool : public dnnl::threadpool_interop::threadpool_iface {
return 0;
};
(void)tp_->ParallelLaunch(func, nullptr, n_jobs);
if (need_change_flag) {
first_parallel = true;
}
}
};
#endif

View File

@ -106,10 +106,13 @@ bool Worker::RunLocalKernelTask() {
if (task == nullptr) {
return false;
}
int task_id = task_id_.load(std::memory_order_consume);
task->status |= task->func(task->content, task_id, lhs_scale_, rhs_scale_);
int task_id_start = task_id_start_.load(std::memory_order_consume);
int task_id_end = task_id_end_.load(std::memory_order_consume);
for (int i = task_id_start; i < task_id_end; ++i) {
task->status |= task->func(task->content, i, lhs_scale_, rhs_scale_);
}
task_.store(nullptr, std::memory_order_relaxed);
(void)++task->finished;
task->finished += task_id_end - task_id_start;
return true;
}
@ -134,11 +137,12 @@ void Worker::set_scale(float lhs_scale, float rhs_scale) {
rhs_scale_ = rhs_scale;
}
void Worker::Active(Task *task, int task_id) {
void Worker::Active(Task *task, int task_id_start, int task_id_end) {
{
std::lock_guard<std::mutex> _l(mutex_);
THREAD_TEST_TRUE(task_ == nullptr);
task_id_.store(task_id, std::memory_order_relaxed);
task_id_start_.store(task_id_start, std::memory_order_relaxed);
task_id_end_.store(task_id_end, std::memory_order_relaxed);
task_.store(task, std::memory_order_release);
status_ = kThreadBusy;
}
@ -236,6 +240,7 @@ int ThreadPool::SyncRunFunc(const Func &func, Content content, int start, int en
void ThreadPool::DistributeTask(Task *task, int task_num, Worker *curr) const {
int sum_frequency = 0;
std::vector<Worker *> assigned;
assigned.reserve(task_num);
int num = static_cast<int>(workers_.size()) - 1;
int offset = 0;
bool use_curr = (curr != nullptr) ? curr->get_task_free() : false;
@ -256,13 +261,9 @@ void ThreadPool::DistributeTask(Task *task, int task_num, Worker *curr) const {
}
}
// when there are not enough free threads,
// distribute other tasks to the master thread
if (use_curr) {
for (; count < task_num; ++count) {
assigned.push_back(curr);
sum_frequency += curr->frequency();
}
assigned.push_back(curr);
sum_frequency += curr->frequency();
} else if (assigned.size() != static_cast<size_t>(task_num)) {
CalculateScales(assigned, sum_frequency);
ActiveWorkers(assigned, task, assigned.size(), curr);
@ -271,7 +272,7 @@ void ThreadPool::DistributeTask(Task *task, int task_num, Worker *curr) const {
}
CalculateScales(assigned, sum_frequency);
ActiveWorkers(assigned, task, assigned.size(), curr);
ActiveWorkers(assigned, task, task_num, curr);
}
void ThreadPool::CalculateScales(const std::vector<Worker *> &assigned, int sum_frequency) const {
@ -292,12 +293,26 @@ void ThreadPool::CalculateScales(const std::vector<Worker *> &assigned, int sum_
void ThreadPool::ActiveWorkers(const std::vector<Worker *> &workers, Task *task, int task_num,
const Worker *curr) const {
for (int i = 0; i < task_num; ++i) {
Worker *worker = workers[i];
THREAD_RETURN_IF_NULL(worker);
worker->Active(task, i);
if (worker == curr) {
(void)worker->RunLocalKernelTask();
// recalculate task num for each worker.
int worker_num = static_cast<int>(workers.size());
if (worker_num > 0) {
int each_worker_task_num = task_num / worker_num;
int rest_task_num = task_num % worker_num;
int start = 0;
int end;
for (int i = 0; i < worker_num; ++i) {
Worker *worker = workers[i];
THREAD_RETURN_IF_NULL(worker);
if (i < rest_task_num) {
end = start + each_worker_task_num + 1;
} else {
end = start + each_worker_task_num;
}
worker->Active(task, start, end);
if (worker == curr) {
(void)worker->RunLocalKernelTask();
}
start = end;
}
}
}

View File

@ -75,7 +75,7 @@ class Worker {
// create thread and start running at the same time
virtual void CreateThread();
// assign task and then activate thread
void Active(Task *task, int task_id);
void Active(Task *task, int task_id_start, int task_id_end);
// activate thread
virtual void Active();
// whether or not it is idle and marked as held
@ -127,7 +127,8 @@ class Worker {
std::condition_variable cond_var_;
std::atomic<Task *> task_{nullptr};
std::atomic_int task_id_{0};
std::atomic_int task_id_start_{0};
std::atomic_int task_id_end_{0};
float lhs_scale_{0.};
float rhs_scale_{kMaxScale};
int frequency_{kDefaultFrequency};

View File

@ -6,7 +6,7 @@ index 1397073ba..041a3436f 100644
// function supports process affinity.
unsigned get_max_threads_to_use() {
int num_cores_per_socket = (int)dnnl::impl::cpu::platform::get_num_cores();
+ if (num_cores_per_socket == 0)
+ if (num_cores_per_socket <= 1)
+ num_cores_per_socket = std::thread::hardware_concurrency();
#if defined(_WIN32)
DWORD_PTR proc_affinity_mask;