From d406653e8554e54222ef81495923a0df0ebf04ed Mon Sep 17 00:00:00 2001 From: fan-jibin Date: Wed, 15 Dec 2021 11:46:45 +0800 Subject: [PATCH] optimize cpu ops --- .../cpu/arithmetic_logic_cpu_kernel.cc | 140 ++++++------------ .../cpu/arithmetic_logic_cpu_kernel.h | 3 + .../cpu/arithmetic_self_cpu_kernel.h | 2 + .../kernel_compiler/cpu/assign_cpu_kernel.cc | 52 ++----- .../cpu/embedding_look_up_cpu_kernel.cc | 31 +--- .../kernel_compiler/cpu/gather_cpu_kernel.cc | 43 ++---- .../kernel_compiler/cpu/gather_cpu_kernel.h | 1 - .../cpu/mkldnn/addn_cpu_kernel.cc | 59 +++++--- .../cpu/mkldnn/addn_cpu_kernel.h | 2 + .../cpu/mkldnn/eltwise_cpu_kernel.h | 2 - .../cpu/mkldnn/matmul_cpu_kernel.cc | 76 ++++++---- .../cpu/mkldnn/matmul_cpu_kernel.h | 12 -- .../cpu/transpose_cpu_kernel.cc | 22 +-- .../runtime/framework/actor/actor_common.h | 8 +- 14 files changed, 184 insertions(+), 269 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_logic_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_logic_cpu_kernel.cc index 7c337feabf8..28ecf2ef0fa 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_logic_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_logic_cpu_kernel.cc @@ -32,137 +32,91 @@ constexpr size_t kOutputsNum = 1; } // namespace template -void ArithmeticLogicCPUKernel::Less(const T *input1, const T *input2, bool *out) { - BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_); - if (output_size_ > kMaxLessSerialSize) { - auto task = [&](size_t start, size_t end) { +template +void ArithmeticLogicCPUKernel::BinaryOp(const T *input1, const T *input2, bool *out, Op op) { + size_t input1_size = 1; + size_t input2_size = 1; + + for (size_t i = 0; i < output_shape_.size(); i++) { + input1_size *= input_shape1_[i]; + input2_size *= input_shape2_[i]; + } + + if (input_shape1_ == input_shape2_) { + auto task = [this, input1, input2, out, op](size_t start, size_t end) { + for 
(size_t i = start; i < end; i++) { + out[i] = op(input1[i], input2[i]); + } + }; + ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_); + } else if (input1_size == 1) { + auto task = [this, input1, input2, out, op](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + out[i] = op(input1[0], input2[i]); + } + }; + ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_); + } else if (input2_size == 1) { + auto task = [this, input1, input2, out, op](size_t start, size_t end) { + for (size_t i = start; i < end; i++) { + out[i] = op(input1[i], input2[0]); + } + }; + ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_); + } else { + BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_); + auto task = [this, input1, input2, out, op, &base_iter](size_t start, size_t end) { auto iter = base_iter; iter.SetPos(start); for (size_t i = start; i < end; i++) { auto x = input1[iter.GetInputPosA()]; auto y = input2[iter.GetInputPosB()]; - out[i] = std::less()(x, y); + out[i] = op(x, y); iter.GenNextPos(); } }; ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_); - } else { - base_iter.SetPos(0); - for (size_t i = 0; i < output_size_; i++) { - auto x = input1[base_iter.GetInputPosA()]; - auto y = input2[base_iter.GetInputPosB()]; - out[i] = std::less()(x, y); - base_iter.GenNextPos(); - } } } +template +void ArithmeticLogicCPUKernel::Less(const T *input1, const T *input2, bool *out) { + BinaryOp(input1, input2, out, std::less()); +} + template void ArithmeticLogicCPUKernel::Equal(const T *input1, const T *input2, bool *out) { - BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_); - auto task = [&](size_t start, size_t end) { - auto iter = base_iter; - iter.SetPos(start); - for (size_t i = start; i < end; i++) { - auto x = input1[iter.GetInputPosA()]; - auto y = input2[iter.GetInputPosB()]; - out[i] = std::equal_to()(x, y); - iter.GenNextPos(); - } - }; - 
ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_); + BinaryOp(input1, input2, out, std::equal_to()); } template void ArithmeticLogicCPUKernel::NotEqual(const T *input1, const T *input2, bool *out) { - BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_); - auto task = [&](size_t start, size_t end) { - auto iter = base_iter; - iter.SetPos(start); - for (size_t i = start; i < end; i++) { - auto x = input1[iter.GetInputPosA()]; - auto y = input2[iter.GetInputPosB()]; - out[i] = std::not_equal_to()(x, y); - iter.GenNextPos(); - } - }; - ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_); + BinaryOp(input1, input2, out, std::not_equal_to()); } template void ArithmeticLogicCPUKernel::LogicalAnd(const T *input1, const T *input2, bool *out) { - BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_); - auto task = [&](size_t start, size_t end) { - auto iter = base_iter; - iter.SetPos(start); - for (size_t i = start; i < end; i++) { - out[i] = input1[iter.GetInputPosA()] && input2[iter.GetInputPosB()]; - iter.GenNextPos(); - } - }; - ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_); + BinaryOp(input1, input2, out, std::logical_and()); } template void ArithmeticLogicCPUKernel::LogicalOr(const T *input1, const T *input2, bool *out) { - BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_); - auto task = [&](size_t start, size_t end) { - auto iter = base_iter; - iter.SetPos(start); - for (size_t i = start; i < end; i++) { - out[i] = input1[iter.GetInputPosA()] || input2[iter.GetInputPosB()]; - iter.GenNextPos(); - } - }; - ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_); + BinaryOp(input1, input2, out, std::logical_or()); } template void ArithmeticLogicCPUKernel::Greater(const T *input1, const T *input2, bool *out) { - BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_); - auto task = [&](size_t start, size_t end) { 
- auto iter = base_iter; - iter.SetPos(start); - for (size_t i = start; i < end; i++) { - auto x = input1[iter.GetInputPosA()]; - auto y = input2[iter.GetInputPosB()]; - out[i] = std::greater()(x, y); - iter.GenNextPos(); - } - }; - ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_); + BinaryOp(input1, input2, out, std::greater()); } template void ArithmeticLogicCPUKernel::GreaterEqual(const T *input1, const T *input2, bool *out) { - BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_); - auto task = [&](size_t start, size_t end) { - auto iter = base_iter; - iter.SetPos(start); - for (size_t i = start; i < end; i++) { - auto x = input1[iter.GetInputPosA()]; - auto y = input2[iter.GetInputPosB()]; - out[i] = std::greater_equal()(x, y); - iter.GenNextPos(); - } - }; - ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_); + BinaryOp(input1, input2, out, std::greater_equal()); } template void ArithmeticLogicCPUKernel::LessEqual(const T *input1, const T *input2, bool *out) { - BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_); - auto task = [&](size_t start, size_t end) { - auto iter = base_iter; - iter.SetPos(start); - for (size_t i = start; i < end; i++) { - auto x = input1[iter.GetInputPosA()]; - auto y = input2[iter.GetInputPosB()]; - out[i] = std::less_equal()(x, y); - iter.GenNextPos(); - } - }; - ParallelLaunchAutoSearch(task, output_size_, this, ¶llel_search_info_); + BinaryOp(input1, input2, out, std::less_equal()); } template diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_logic_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_logic_cpu_kernel.h index cdebb3492b1..b4edb633002 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_logic_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_logic_cpu_kernel.h @@ -39,6 +39,9 @@ class ArithmeticLogicCPUKernel : public CPUKernel { private: void InitComputeFunc(); 
+ template + void BinaryOp(const T *input1, const T *input2, bool *out, Op op); + void Less(const T *input1, const T *input2, bool *out); void Equal(const T *input1, const T *input2, bool *out); void NotEqual(const T *input1, const T *input2, bool *out); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h index d58889b81c9..b6942c9fa0c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h @@ -156,6 +156,8 @@ MS_REG_CPU_KERNEL(Abs, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr ArithmeticSelfCPUKernel); MS_REG_CPU_KERNEL(Abs, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), ArithmeticSelfCPUKernel); +MS_REG_CPU_KERNEL(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ArithmeticSelfCPUKernel); MS_REG_CPU_KERNEL(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), ArithmeticSelfCPUKernel); MS_REG_CPU_KERNEL(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/assign_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/assign_cpu_kernel.cc index ce0c1d0bfd8..6a54a12baae 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/assign_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/assign_cpu_kernel.cc @@ -74,46 +74,26 @@ bool AssignCPUKernel::Launch(const std::vector &inputs, const std::v << "', memcpy size must be less than or equal to max size, but got memcpy size: " << total_size << ", and max size: " << max_size; } - constexpr size_t kBlockSize = 10000; - size_t thread_num = (total_size + kBlockSize - 1) / kBlockSize; - auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); - thread_num = thread_num > 
max_thread_num ? max_thread_num : thread_num; - if (thread_num == 0) { - return true; - } - size_t stride = total_size / thread_num; - std::vector tasks; - size_t thread_index = 0; + auto input0_addr = reinterpret_cast(inputs[0]->addr); auto input1_addr = reinterpret_cast(inputs[1]->addr); auto output_addr = reinterpret_cast(outputs[0]->addr); - size_t length = stride; - while (thread_index < thread_num) { - auto thread_stride = stride * thread_index; - size_t max_length = total_size - thread_stride; - if (thread_index == thread_num - 1) { - length = max_length; + auto task = [&](size_t start, size_t end) { + int8_t *input0 = input0_addr + start; + int8_t *input1 = input1_addr + start; + int8_t *output = output_addr + start; + size_t length = end - start; + size_t max_length = total_size - start; + int ret = memcpy_s(input0, max_length, input1, length); + if (ret != 0) { + MS_LOG(EXCEPTION) << "For '" << kernel_name << "', memcpy_s error. Error no " << ret; } - int8_t *input0 = input0_addr + thread_stride; - int8_t *input1 = input1_addr + thread_stride; - int8_t *output = output_addr + thread_stride; - auto block = [input0, input1, output, max_length, length]() { - int ret = memcpy_s(input0, max_length, input1, length); - if (ret != 0) { - MS_LOG(ERROR) << "For '" << kernel_name << "', memcpy_s error. Error no " << ret; - return common::FAIL; - } - ret = memcpy_s(output, max_length, input1, length); - if (ret != 0) { - MS_LOG(ERROR) << "For '" << kernel_name << "', memcpy_s error. Error no " << ret; - return common::FAIL; - } - return common::SUCCESS; - }; - (void)tasks.emplace_back(block); - thread_index++; - } - (void)common::ThreadPool::GetInstance().SyncRun(tasks); + ret = memcpy_s(output, max_length, input1, length); + if (ret != 0) { + MS_LOG(EXCEPTION) << "For '" << kernel_name << "', memcpy_s error. 
Error no " << ret; + } + }; + ParallelLaunchAutoSearch(task, total_size, this, ¶llel_search_info_); return true; } } // namespace kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc index 183593963bf..e27461cd758 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc @@ -106,31 +106,12 @@ void EmbeddingLookUpCPUKernel::LaunchKernel(const std::vector(inputs[0]->addr); const auto *indices_addr = reinterpret_cast(inputs[1]->addr); auto *output_addr = reinterpret_cast(outputs[0]->addr); - size_t thread_num = indices_lens_ / kBlockSize + 1; - auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); - thread_num = thread_num > max_thread_num ? max_thread_num : thread_num; - std::vector tasks; - size_t task_proc_lens = (indices_lens_ + thread_num - 1) / thread_num; - size_t i; - size_t task_offset = 0; - MS_LOG(DEBUG) << "indices_lens_: " << indices_lens_ << " one task proc lens:" << task_proc_lens; - for (i = 0; i < thread_num; i++) { - if (task_offset >= indices_lens_) { - break; - } - MS_LOG(DEBUG) << "task_offset: " << task_offset << " task_proc_lenss:" << task_proc_lens; - auto task = [input_addr, indices_addr, output_addr, task_offset, task_proc_lens, this]() { - LookUpTableTask(input_addr, indices_addr + task_offset, output_addr + task_offset * outer_dim_size_, - task_proc_lens, outer_dim_size_, static_cast(offset_), first_dim_size_, kernel_name_); - return common::SUCCESS; - }; - (void)tasks.emplace_back(task); - task_offset += task_proc_lens; - if (task_offset + task_proc_lens > indices_lens_) { - task_proc_lens = indices_lens_ - task_offset; - } - } - (void)common::ThreadPool::GetInstance().SyncRun(tasks); + auto task = [&](size_t start, size_t end) { + size_t task_proc_lens = end - start; + 
LookUpTableTask(input_addr, indices_addr + start, output_addr + start * outer_dim_size_, task_proc_lens, + outer_dim_size_, static_cast(offset_), first_dim_size_, kernel_name_); + }; + ParallelLaunchAutoSearch(task, indices_lens_, this, ¶llel_search_info_); } bool EmbeddingLookUpCPUKernel::Launch(const std::vector &inputs, diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.cc index 6658de5c27a..d576960f39b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.cc @@ -66,14 +66,6 @@ bool GatherV2CPUKernel::Launch(const std::vector &inputs, axis_ = axis_ + dims; } - int max_thread_num = SizeToInt(common::ThreadPool::GetInstance().GetSyncRunThreadNum()); - ParallelRun(input_tensor, indices_data, output_addr, max_thread_num); - return true; -} - -template -void GatherV2CPUKernel::ParallelRun(const int8_t *input_addr, const int *indices_data, int8_t *output_addr, - int thread_num) { size_t outer_size = 1, inner_size = 1; auto axis = static_cast(axis_); for (size_t i = 0; i < axis; ++i) { @@ -87,32 +79,17 @@ void GatherV2CPUKernel::ParallelRun(const int8_t *input_addr, const int *indi indices_element_size *= indices_shape_.at(i); } auto limit = input_shape_.at(axis); - size_t stride = UP_DIV(outer_size, IntToSize(thread_num)); - std::vector tasks; - int thread_index = 0; - while (thread_index < thread_num) { - int count = MSMIN(SizeToInt(stride), SizeToInt(outer_size) - SizeToInt(stride) * thread_index); - if (count <= 0) { - break; + auto task = [&](size_t start, size_t end) { + int count = SizeToInt(end - start); + const int8_t *in = input_tensor + start * limit * inner_size * sizeof(T); + int8_t *out = output_addr + start * indices_element_size * inner_size * sizeof(T); + int ret = Gather(in, count, inner_size, limit, indices_data, indices_element_size, out, sizeof(T)); + if (ret != 0) { + 
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', error_code[" << ret << "]"; } - auto thread_stride = static_cast(stride * thread_index); - const int8_t *in = input_addr + thread_stride * limit * inner_size * sizeof(T); - int8_t *out = output_addr + thread_stride * indices_element_size * inner_size * sizeof(T); - auto block = [this, in, indices_data, count, inner_size, limit, indices_element_size, out, thread_index]() { - int ret = Gather(in, count, inner_size, limit, indices_data, indices_element_size, out, sizeof(T)); - if (ret != 0) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "', run error task_id[" << thread_index << "] error_code[" << ret - << "]"; - return common::FAIL; - } - return common::SUCCESS; - }; - (void)tasks.emplace_back(block); - thread_index++; - } - if (!common::ThreadPool::GetInstance().SyncRun(tasks)) { - MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', syncRun error."; - } + }; + ParallelLaunchAutoSearch(task, outer_size, this, ¶llel_search_info_); + return true; } } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.h index eeb55a7117f..bf8231e7ed5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/gather_cpu_kernel.h @@ -37,7 +37,6 @@ class GatherV2CPUKernel : public CPUKernel { const std::vector &outputs) override; private: - void ParallelRun(const int8_t *input_addr, const int *indices_data, int8_t *output_addr, int thread_num); std::vector input_shape_; std::vector indices_shape_; std::vector output_shape_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.cc index e93e5fd8cff..7c035d6e12f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.cc +++ 
b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.cc @@ -35,6 +35,13 @@ void AddInt(const int *in_0, const int *in_1, int *out, int start, int end) { } } +void AddFloat(const float *in_0, const float *in_1, float *out, int start, int end) { + int ret = ElementAdd(in_0 + start, in_1 + start, out + start, end - start); + if (ret != NNACL_OK) { + MS_LOG(EXCEPTION) << "Add failed."; + } +} + void AddDouble(const double *in0, const double *in1, double *out, int start, int end) { for (int index = start; index < end; index++) { out[index] = in0[index] + in1[index]; @@ -81,35 +88,43 @@ bool AddNCPUKernel::Launch(const std::vector &inputs, const ExecutePrimitive(); } } else if (dtype_ == kNumberTypeInt32) { - size_t elements_num = outputs[0]->size / sizeof(int); - const auto input_0 = reinterpret_cast(inputs[0]->addr); - const auto input_1 = reinterpret_cast(inputs[1]->addr); - auto output = reinterpret_cast(outputs[0]->addr); - auto task_0 = std::bind(AddInt, input_0, input_1, output, std::placeholders::_1, std::placeholders::_2); - ParallelLaunchAutoSearch(task_0, elements_num, this, ¶llel_search_info_); - for (size_t index = 2; index < input_num_; ++index) { - const auto input = reinterpret_cast(inputs[index]->addr); - auto task = std::bind(AddInt, input, output, output, std::placeholders::_1, std::placeholders::_2); - ParallelLaunchAutoSearch(task, elements_num, this, ¶llel_search_info_); - } + LaunchNnacl(inputs, outputs); } else if (dtype_ == kNumberTypeFloat64) { - size_t elements_num = outputs[0]->size / sizeof(double); - const auto input_0 = reinterpret_cast(inputs[0]->addr); - const auto input_1 = reinterpret_cast(inputs[1]->addr); - auto output = reinterpret_cast(outputs[0]->addr); - auto task_0 = std::bind(AddDouble, input_0, input_1, output, std::placeholders::_1, std::placeholders::_2); - CPUKernelUtils::ParallelFor(task_0, elements_num); - for (size_t index = 2; index < input_num_; ++index) { - const auto input = 
reinterpret_cast(inputs[index]->addr); - auto task = std::bind(AddDouble, input, output, output, std::placeholders::_1, std::placeholders::_2); - CPUKernelUtils::ParallelFor(task, elements_num); - } + LaunchNnacl(inputs, outputs); } else { MS_LOG(EXCEPTION) << "AddN only support float32, float64 and int32, but got " << TypeIdToType(dtype_)->ToString(); } return true; } +template +void AddNCPUKernel::LaunchNnacl(const std::vector &inputs, + const std::vector &outputs) { + std::function m_func; + if constexpr (std::is_same::value) { + m_func = AddFloat; + } else if constexpr (std::is_same::value) { + m_func = AddInt; + } else if constexpr (std::is_same::value) { + m_func = AddDouble; + } else { + MS_LOG(EXCEPTION) << "AddN only support float32, float64 and int32, but got " << TypeIdToType(dtype_)->ToString(); + } + + size_t elements_num = outputs[0]->size / sizeof(T); + const auto input_0 = reinterpret_cast(inputs[0]->addr); + const auto input_1 = reinterpret_cast(inputs[1]->addr); + auto output = reinterpret_cast(outputs[0]->addr); + auto task_0 = std::bind(m_func, input_0, input_1, output, std::placeholders::_1, std::placeholders::_2); + ParallelLaunchAutoSearch(task_0, elements_num, this, ¶llel_search_info_); + const size_t iter_start = 2; + for (size_t index = iter_start; index < input_num_; ++index) { + const auto input = reinterpret_cast(inputs[index]->addr); + auto task = std::bind(m_func, input, output, output, std::placeholders::_1, std::placeholders::_2); + ParallelLaunchAutoSearch(task, elements_num, this, ¶llel_search_info_); + } +} + void AddNCPUKernel::CheckParam(const CNodePtr &kernel_node) { auto src0_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); auto dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.h index 12f02f81c5e..d1c500e8230 100644 --- 
a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.h @@ -33,6 +33,8 @@ class AddNCPUKernel : public MKLCPUKernel { const std::vector &outputs) override; private: + template + void LaunchNnacl(const std::vector &inputs, const std::vector &outputs); void CheckParam(const CNodePtr &kernel_node); size_t input_num_{0}; std::vector output_shape_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/eltwise_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/eltwise_cpu_kernel.h index d3f07f33e6b..4377e1ca4ea 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/eltwise_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/eltwise_cpu_kernel.h @@ -45,8 +45,6 @@ MS_REG_CPU_KERNEL(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputA EltWiseCPUKernel); MS_REG_CPU_KERNEL(ReLU6, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), EltWiseCPUKernel); -MS_REG_CPU_KERNEL(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - EltWiseCPUKernel); MS_REG_CPU_KERNEL(Exp, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), EltWiseCPUKernel); MS_REG_CPU_KERNEL(Log, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.cc index e14d58e912c..634ba4a3e3c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.cc @@ -19,6 +19,8 @@ #include "common/thread_pool.h" #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "backend/kernel_compiler/cpu/nnacl/op_base.h" +#include "backend/kernel_compiler/cpu/nnacl/matmul_parameter.h" +#include 
"backend/kernel_compiler/cpu/nnacl/fp32/matmul_fp32.h" #include "runtime/device/cpu/cpu_device_address.h" #include "utils/ms_utils.h" @@ -27,39 +29,64 @@ namespace kernel { namespace { constexpr size_t kMatMulInputsNum = 2; constexpr size_t kMatMulOutputsNum = 1; -const size_t kIndexOffset = 2; +constexpr size_t kIndexOffset = 2; +constexpr size_t kRankMin = 2; +using dims = dnnl::memory::dims; } // namespace + void MatMulCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); kernel_name_ = AnfAlgo::GetCNodeName(kernel_node); std::vector a_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); std::vector b_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); std::vector o_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); - const size_t rank_min = 2; - if (a_shape.size() < rank_min || b_shape.size() < rank_min || o_shape.size() < rank_min) { - MS_LOG(EXCEPTION) << "The tensor rank of MatMul should be greater than or equal to 2."; + if (a_shape.size() < kRankMin || b_shape.size() < kRankMin || o_shape.size() < kRankMin) { + MS_LOG(EXCEPTION) << "The tensor rank of MatMul should be greater than or equal to " << kRankMin; } bool trans_a = AnfAlgo::GetNodeAttr(kernel_node, TRANSPOSE_A); bool trans_b = AnfAlgo::GetNodeAttr(kernel_node, TRANSPOSE_B); - rank_ = a_shape.size(); - batch_ = 1; - for (size_t i = 0; i < rank_ - kIndexOffset; ++i) { - batch_ *= a_shape[i]; + auto rank = a_shape.size(); + int64_t batch = 1; + for (size_t i = 0; i < rank - kIndexOffset; ++i) { + batch *= SizeToLong(a_shape[i]); } - size_mat_a_ = a_shape[rank_ - kIndexOffset] * a_shape[rank_ - 1]; - size_mat_b_ = b_shape[rank_ - kIndexOffset] * b_shape[rank_ - 1]; - size_mat_o_ = o_shape[rank_ - kIndexOffset] * o_shape[rank_ - 1]; + + int64_t dim_m = SizeToLong(o_shape[rank - kIndexOffset]); + int64_t dim_n = SizeToLong(o_shape[rank - 1]); + int64_t dim_k = 1; if (trans_a) { - trans_a_ = TRANSPOSE_YES; - dim_k_ = static_cast(a_shape[rank_ - kIndexOffset]); + 
dim_k = SizeToLong(a_shape[rank - kIndexOffset]); } else { - dim_k_ = static_cast(a_shape[rank_ - 1]); + dim_k = SizeToLong(a_shape[rank - 1]); } - if (trans_b) { - trans_b_ = TRANSPOSE_YES; + + dims src_dims, weights_dims, dst_dims, a_strides, b_strides, o_strides; + if (batch > 1) { + src_dims = {batch, dim_m, dim_k}; + weights_dims = {batch, dim_k, dim_n}; + dst_dims = {batch, dim_m, dim_n}; + a_strides = {trans_a ? dims{dim_m * dim_k, 1, dim_m} : dims{dim_m * dim_k, dim_k, 1}}; + b_strides = {trans_b ? dims{dim_n * dim_k, 1, dim_k} : dims{dim_n * dim_k, dim_n, 1}}; + o_strides = {dim_n * dim_m, dim_n, 1}; + } else { + src_dims = {dim_m, dim_k}; + weights_dims = {dim_k, dim_n}; + dst_dims = {dim_m, dim_n}; + a_strides = {trans_a ? dims{1, dim_m} : dims{dim_k, 1}}; + b_strides = {trans_b ? dims{1, dim_k} : dims{dim_n, 1}}; + o_strides = {dim_n, 1}; } - dim_m_ = static_cast(o_shape[rank_ - kIndexOffset]); - dim_n_ = static_cast(o_shape[rank_ - 1]); + + dnnl::memory::desc src_md(src_dims, dnnl::memory::data_type::f32, a_strides); + dnnl::memory::desc weights_md(weights_dims, dnnl::memory::data_type::f32, b_strides); + dnnl::memory::desc dst_md(dst_dims, dnnl::memory::data_type::f32, o_strides); + dnnl::matmul::desc matmul_desc(src_md, weights_md, dst_md); + dnnl::matmul::primitive_desc prim_desc(matmul_desc, MKLKernelEngine::Get().engine()); + primitive_ = std::make_shared(prim_desc); + + AddArgument(DNNL_ARG_SRC, src_md); + AddArgument(DNNL_ARG_WEIGHTS, weights_md); + AddArgument(DNNL_ARG_DST, dst_md); } bool MatMulCPUKernel::Launch(const std::vector &inputs, const std::vector &, @@ -70,15 +97,10 @@ bool MatMulCPUKernel::Launch(const std::vector &inputs, cons const auto input_b = reinterpret_cast(inputs[1]->addr); auto output = reinterpret_cast(outputs[0]->addr); - dnnl_dim_t lda = (trans_a_ == TRANSPOSE_YES ? dim_m_ : dim_k_); - dnnl_dim_t ldb = (trans_b_ == TRANSPOSE_YES ? 
dim_k_ : dim_n_); - dnnl_dim_t ldc = dim_n_; - float alpha = 1.0; - float beta = 0.0; - for (size_t i = 0; i < batch_; i++) { - (void)dnnl_sgemm(trans_a_, trans_b_, dim_m_, dim_n_, dim_k_, alpha, input_a + i * size_mat_a_, lda, - input_b + i * size_mat_b_, ldb, beta, output + i * size_mat_o_, ldc); - } + SetArgumentHandle(DNNL_ARG_SRC, input_a); + SetArgumentHandle(DNNL_ARG_WEIGHTS, input_b); + SetArgumentHandle(DNNL_ARG_DST, output); + ExecutePrimitive(); return true; } } // namespace kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.h index e3703b5473c..e4aced84281 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/matmul_cpu_kernel.h @@ -32,18 +32,6 @@ class MatMulCPUKernel : public MKLCPUKernel { bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) override; - - private: - dnnl_dim_t dim_m_{0}; - dnnl_dim_t dim_n_{0}; - dnnl_dim_t dim_k_{0}; - size_t batch_{0}; - size_t rank_{0}; - size_t size_mat_a_{0}; - size_t size_mat_b_{0}; - size_t size_mat_o_{0}; - char trans_a_{TRANSPOSE_NO}; - char trans_b_{TRANSPOSE_NO}; }; MS_REG_CPU_KERNEL( MatMul, diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/transpose_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/transpose_cpu_kernel.cc index c4b77f6b11d..a676475b0a3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/transpose_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/transpose_cpu_kernel.cc @@ -133,13 +133,7 @@ void TransposeCPUFwdKernel::LaunchKernel(const std::vector &inputs, template void TransposeCPUFwdKernel::ParallelRun(const T *input_addr, T *output_addr, const int *output_shape, size_t count) { - auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); - const float block_size = 128.0; - const size_t 
thread_num = - count < block_size * max_thread_num ? FloatToSize(std::ceil(count / block_size)) : max_thread_num; - std::vector tasks; std::function TransposeDims; - if constexpr (std::is_same_v) { TransposeDims = &TransposeDimsInt8; } else if constexpr (std::is_same_v) { @@ -163,14 +157,14 @@ void TransposeCPUFwdKernel::ParallelRun(const T *input_addr, T *output_addr, con } else if constexpr (std::is_same_v) { TransposeDims = &TransposeDimsBool; } - for (int task_id = 0; task_id < SizeToInt(thread_num); ++task_id) { - auto task = [this, &TransposeDims, &input_addr, &output_addr, &output_shape, task_id, thread_num]() { - TransposeDims(input_addr, output_addr, output_shape, &transpose_param_, task_id, SizeToInt(thread_num)); - return common::SUCCESS; - }; - (void)tasks.emplace_back(task); - } - (void)common::ThreadPool::GetInstance().SyncRun(tasks); + auto thread_pool = GetActorMgrInnerThreadPool(); + size_t thread_num = thread_pool->GetKernelThreadNum(); + auto task = [this, &TransposeDims, input_addr, output_addr, output_shape, thread_num](size_t start, size_t end) { + for (size_t idx = start; idx < end; idx++) { + TransposeDims(input_addr, output_addr, output_shape, &transpose_param_, SizeToInt(idx), SizeToInt(thread_num)); + } + }; + ParallelLaunchAutoSearch(task, thread_num, this, ¶llel_search_info_); } } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/framework/actor/actor_common.h b/mindspore/ccsrc/runtime/framework/actor/actor_common.h index 6a9fa1b339b..a8ef1dac087 100644 --- a/mindspore/ccsrc/runtime/framework/actor/actor_common.h +++ b/mindspore/ccsrc/runtime/framework/actor/actor_common.h @@ -150,10 +150,10 @@ class ActorDispatcher { // The first five executions are for warm-up, the next five executions are statistics of multi thread execution time, // and the next next five executions are statistics of single thread execution time. 
- static constexpr size_t kMultiThreadExecutionCountBegin{6}; - static constexpr size_t kMultiThreadExecutionCountEnd{10}; - static constexpr size_t kSingleThreadExecutionCountBegin{11}; - static constexpr size_t kSingleThreadExecutionCountEnd{15}; + static constexpr size_t kMultiThreadExecutionCountBegin{31}; + static constexpr size_t kMultiThreadExecutionCountEnd{40}; + static constexpr size_t kSingleThreadExecutionCountBegin{41}; + static constexpr size_t kSingleThreadExecutionCountEnd{50}; // The single thread execution constraint. static constexpr size_t kSingleThreadExecutionActorMaxNum{100};