From 1d5f77d075e5b3d692710506283cf9a218b42789 Mon Sep 17 00:00:00 2001 From: wuxuejian Date: Wed, 31 Mar 2021 09:43:26 +0800 Subject: [PATCH] Refactor eltwisegrad cpu ops --- .../cpu/eltwise_grad_cpu_kernel.cc | 173 +++++------------- .../cpu/eltwise_grad_cpu_kernel.h | 84 ++++----- 2 files changed, 77 insertions(+), 180 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc index f38180dece7..97ad82e77c4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc @@ -14,15 +14,16 @@ * limitations under the License. */ #include -#include -#include +#include #include "backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.h" +#include "common/thread_pool.h" #include "runtime/device/cpu/cpu_device_address.h" namespace mindspore { namespace kernel { + template -void EltWiseGradCPUKernel::ReluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) { +void EltWiseGradCPUKernel::ReluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) { for (size_t i = start; i < end; i++) { if (input2[i] > 0) { out[i] = input1[i]; @@ -33,7 +34,7 @@ void EltWiseGradCPUKernel::ReluGrad(const T *input1, const T *input2, T *out, si } template -void EltWiseGradCPUKernel::ReLU6Grad(const T *input1, const T *input2, T *out, size_t start, size_t end) { +void EltWiseGradCPUKernel::ReLU6Grad(const T *input1, const T *input2, T *out, size_t start, size_t end) { for (size_t i = start; i < end; i++) { if (input2[i] > 0 && input2[i] <= 6) { out[i] = input1[i]; @@ -44,7 +45,7 @@ void EltWiseGradCPUKernel::ReLU6Grad(const T *input1, const T *input2, T *out, s } template -void EltWiseGradCPUKernel::AbsGrad(const T *input1, const T *input2, T *out, size_t 
start, size_t end) { for (size_t i = start; i < end; i++) { if (input1[i] > 0) { out[i] = input2[i]; @@ -57,21 +58,21 @@ void EltWiseGradCPUKernel::AbsGrad(const T *input1, const T *input2, T *out, siz } template -void EltWiseGradCPUKernel::SigmoidGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) { +void EltWiseGradCPUKernel::SigmoidGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) { for (size_t i = start; i < end; i++) { out[i] = input2[i] * input1[i] * (1 - input1[i]); } } template -void EltWiseGradCPUKernel::SqrtGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) { +void EltWiseGradCPUKernel::SqrtGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) { for (size_t i = start; i < end; i++) { out[i] = input2[i] / (input1[i] * 2); } } template -void EltWiseGradCPUKernel::TanhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) { +void EltWiseGradCPUKernel::TanhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) { for (size_t i = start; i < end; i++) { T tmp = input1[i] * input1[i]; out[i] = input2[i] * (1 - tmp); @@ -79,7 +80,7 @@ void EltWiseGradCPUKernel::TanhGrad(const T *input1, const T *input2, T *out, si } template -void EltWiseGradCPUKernel::GeluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) { +void EltWiseGradCPUKernel::GeluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) { for (size_t i = start; i < end; i++) { T x = input2[i]; auto double_x = static_cast(x); @@ -91,7 +92,7 @@ void EltWiseGradCPUKernel::GeluGrad(const T *input1, const T *input2, T *out, si } template -void EltWiseGradCPUKernel::AsinGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) { +void EltWiseGradCPUKernel::AsinGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) { for (size_t i = start; i < end; i++) { T dividend = input2[i]; T divisor = sqrt(1 - input1[i] * input1[i]); 
@@ -112,7 +113,7 @@ void EltWiseGradCPUKernel::AsinGrad(const T *input1, const T *input2, T *out, si } template -void EltWiseGradCPUKernel::ACosGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) { +void EltWiseGradCPUKernel::ACosGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) { for (size_t i = start; i < end; i++) { T dividend = -input2[i]; T divisor = sqrt(1 - input1[i] * input1[i]); @@ -133,7 +134,7 @@ void EltWiseGradCPUKernel::ACosGrad(const T *input1, const T *input2, T *out, si } template -void EltWiseGradCPUKernel::AtanGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) { +void EltWiseGradCPUKernel::AtanGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) { for (size_t i = start; i < end; i++) { T dividend = input2[i]; T divisor = 1 + input1[i] * input1[i]; @@ -154,7 +155,7 @@ void EltWiseGradCPUKernel::AtanGrad(const T *input1, const T *input2, T *out, si } template -void EltWiseGradCPUKernel::AsinhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) { +void EltWiseGradCPUKernel::AsinhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) { for (size_t i = start; i < end; i++) { T dividend = input2[i]; T divisor = sqrt(1 + input1[i] * input1[i]); @@ -175,7 +176,7 @@ void EltWiseGradCPUKernel::AsinhGrad(const T *input1, const T *input2, T *out, s } template -void EltWiseGradCPUKernel::AcoshGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) { +void EltWiseGradCPUKernel::AcoshGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) { for (size_t i = start; i < end; i++) { T dividend = input2[i]; T divisor = sqrt(input1[i] * input1[i] - 1); @@ -195,132 +196,46 @@ void EltWiseGradCPUKernel::AcoshGrad(const T *input1, const T *input2, T *out, s } } -void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { +template +void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { 
MS_EXCEPTION_IF_NULL(kernel_node); - std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node); - if (kernel_name == "ReluGrad") { - operate_type_ = RELUGRAD; - } else if (kernel_name == "ReLU6Grad") { - operate_type_ = RELU6GRAD; - } else if (kernel_name == "SigmoidGrad") { - operate_type_ = SIGMOIDGRAD; - } else if (kernel_name == "AbsGrad") { - operate_type_ = ABSGRAD; - } else if (kernel_name == "TanhGrad") { - operate_type_ = TANHGRAD; - } else if (kernel_name == "SqrtGrad") { - operate_type_ = SQRTGRAD; - } else if (kernel_name == "GeLUGrad") { - operate_type_ = GELUGRAD; - } else if (kernel_name == "AsinGrad") { - operate_type_ = ASINGRAD; - } else if (kernel_name == "ACosGrad") { - operate_type_ = ACOSGRAD; - } else if (kernel_name == "AtanGrad") { - operate_type_ = ATANGRAD; - } else if (kernel_name == "AsinhGrad") { - operate_type_ = ASINHGRAD; - } else if (kernel_name == "AcoshGrad") { - operate_type_ = ACOSHGRAD; - } else { - MS_LOG(EXCEPTION) << "Not support " << kernel_name; - } - - input_shape0_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - input_shape1_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); - if (output_shape_.size() == 0) { - output_shape_.insert(output_shape_.begin(), 1); - } - size_t l = input_shape0_.size(); - for (size_t i = 0; i < output_shape_.size() - l; ++i) { - input_shape0_.insert(input_shape0_.begin(), 1); - } - l = input_shape1_.size(); - for (size_t i = 0; i < output_shape_.size() - l; ++i) { - input_shape1_.insert(input_shape1_.begin(), 1); - } - CPUKernelUtils::GetElementNumEveryDim(input_shape0_, &input_element_num0_); - CPUKernelUtils::GetElementNumEveryDim(input_shape1_, &input_element_num1_); - CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_); - dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); - if (dtype_ != AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1)) { - MS_LOG(EXCEPTION) << 
"Input0 and input1 must has the same data type"; - } -} - -bool EltWiseGradCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, - const std::vector &outputs) { - if (dtype_ == kNumberTypeInt32 || dtype_ == kNumberTypeInt16) { - LaunchKernel(inputs, outputs); - } else if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat16 || dtype_ == kNumberTypeFloat64) { - LaunchKernel(inputs, outputs); - } else if (dtype_ == kNumberTypeInt64) { - LaunchKernel(inputs, outputs); - } else { - MS_LOG(EXCEPTION) << "Data type is " << TypeIdLabel(dtype_) << "is not support."; - } - return true; + kernel_name_ = AnfAlgo::GetCNodeName(kernel_node); } template -void EltWiseGradCPUKernel::LaunchKernel(const std::vector &inputs, const std::vector &outputs) { +bool EltWiseGradCPUKernel::Launch(const std::vector &inputs, + const std::vector & /*workspace*/, + const std::vector &outputs) { + static const std::map> + elt_map{{"ReluGrad", &EltWiseGradCPUKernel::ReluGrad}, {"ReLU6Grad", &EltWiseGradCPUKernel::ReLU6Grad}, + {"SigmoidGrad", &EltWiseGradCPUKernel::SigmoidGrad}, {"AbsGrad", &EltWiseGradCPUKernel::AbsGrad}, + {"TanhGrad", &EltWiseGradCPUKernel::TanhGrad}, {"SqrtGrad", &EltWiseGradCPUKernel::SqrtGrad}, + {"GeLUGrad", &EltWiseGradCPUKernel::GeluGrad}, {"AsinGrad", &EltWiseGradCPUKernel::AsinGrad}, + {"ACosGrad", &EltWiseGradCPUKernel::ACosGrad}, {"AtanGrad", &EltWiseGradCPUKernel::AtanGrad}, + {"AsinhGrad", &EltWiseGradCPUKernel::AsinhGrad}, {"AcoshGrad", &EltWiseGradCPUKernel::AcoshGrad}}; T *input1 = reinterpret_cast(inputs[0]->addr); T *input2 = reinterpret_cast(inputs[1]->addr); T *output = reinterpret_cast(outputs[0]->addr); - size_t lens = outputs[0]->size > 0 ? static_cast(outputs[0]->size / sizeof(T)) : 1; - auto max_thread_num = std::thread::hardware_concurrency(); - size_t thread_num = lens < 128 * max_thread_num ? 
std::ceil(lens / 128.0) : max_thread_num; - MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num; - std::vector threads; - if (thread_num < 1) { - MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num; - return; - } - threads.reserve(thread_num); + size_t count = outputs[0]->size > 0 ? static_cast(outputs[0]->size / sizeof(T)) : 1; + auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); + const float block_size = 128.0; + size_t thread_num = count < block_size * max_thread_num ? std::ceil(count / block_size) : max_thread_num; + std::vector tasks; size_t start = 0; - size_t once_compute_size = (lens + thread_num - 1) / thread_num; - if (once_compute_size < 1) { - MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size; - return; - } - while (start < lens) { - size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); - if (operate_type_ == RELUGRAD) { - threads.emplace_back(std::thread(&EltWiseGradCPUKernel::ReluGrad, this, input1, input2, output, start, end)); - } else if (operate_type_ == RELU6GRAD) { - threads.emplace_back(std::thread(&EltWiseGradCPUKernel::ReLU6Grad, this, input1, input2, output, start, end)); - } else if (operate_type_ == ABSGRAD) { - threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AbsGrad, this, input1, input2, output, start, end)); - } else if (operate_type_ == SIGMOIDGRAD) { - threads.emplace_back( - std::thread(&EltWiseGradCPUKernel::SigmoidGrad, this, input1, input2, output, start, end)); - } else if (operate_type_ == TANHGRAD) { - threads.emplace_back(std::thread(&EltWiseGradCPUKernel::TanhGrad, this, input1, input2, output, start, end)); - } else if (operate_type_ == SQRTGRAD) { - threads.emplace_back(std::thread(&EltWiseGradCPUKernel::SqrtGrad, this, input1, input2, output, start, end)); - } else if (operate_type_ == GELUGRAD) { - threads.emplace_back(std::thread(&EltWiseGradCPUKernel::GeluGrad, 
this, input1, input2, output, start, end)); - } else if (operate_type_ == ASINGRAD) { - threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AsinGrad, this, input1, input2, output, start, end)); - } else if (operate_type_ == ACOSGRAD) { - threads.emplace_back(std::thread(&EltWiseGradCPUKernel::ACosGrad, this, input1, input2, output, start, end)); - } else if (operate_type_ == ATANGRAD) { - threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AtanGrad, this, input1, input2, output, start, end)); - } else if (operate_type_ == ASINHGRAD) { - threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AsinhGrad, this, input1, input2, output, start, end)); - } else if (operate_type_ == ACOSHGRAD) { - threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AcoshGrad, this, input1, input2, output, start, end)); - } else { - MS_LOG(EXCEPTION) << "Not support " << operate_type_; - } + size_t once_compute_size = (count + thread_num - 1) / thread_num; + while (start < count) { + size_t end = (start + once_compute_size) > count ? 
count : (start + once_compute_size); + auto block = [&, start, end]() { + elt_map.at(kernel_name_)(this, input1, input2, output, start, end); + return common::SUCCESS; + }; + tasks.emplace_back(block); start += once_compute_size; } - for (size_t i = 0; i < threads.size(); ++i) { - threads[i].join(); - } + common::ThreadPool::GetInstance().SyncRun(tasks); + return true; } } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.h index b2ed04cf1b9..89917e44f9d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.h @@ -18,11 +18,13 @@ #include #include #include +#include #include "backend/kernel_compiler/cpu/cpu_kernel.h" #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" namespace mindspore { namespace kernel { +template class EltWiseGradCPUKernel : public CPUKernel { public: EltWiseGradCPUKernel() = default; @@ -32,95 +34,75 @@ class EltWiseGradCPUKernel : public CPUKernel { bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) override; - template - void LaunchKernel(const std::vector &inputs, const std::vector &outputs); private: - template void ReluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end); - template void ReLU6Grad(const T *input1, const T *input2, T *out, size_t start, size_t end); - template void AbsGrad(const T *input1, const T *input2, T *out, size_t start, size_t end); - template void SigmoidGrad(const T *input1, const T *input2, T *out, size_t start, size_t end); - template void SqrtGrad(const T *input1, const T *input2, T *out, size_t start, size_t end); - template void TanhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end); - template void GeluGrad(const T *input1, const T *input2, T *out, size_t start, 
size_t end); - template void AsinGrad(const T *input1, const T *input2, T *out, size_t start, size_t end); - template void ACosGrad(const T *input1, const T *input2, T *out, size_t start, size_t end); - template void AtanGrad(const T *input1, const T *input2, T *out, size_t start, size_t end); - template void AsinhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end); - template void AcoshGrad(const T *input1, const T *input2, T *out, size_t start, size_t end); - std::vector input_shape0_; - std::vector input_shape1_; - std::vector input_element_num0_; - std::vector input_element_num1_; - std::vector output_shape_; - std::vector output_element_num_; - OperateType operate_type_{RELUGRAD}; - TypeId dtype_{kTypeUnknown}; + + std::string kernel_name_ = ""; }; -MS_REG_CPU_KERNEL( +MS_REG_CPU_KERNEL_T( ReluGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - EltWiseGradCPUKernel); -MS_REG_CPU_KERNEL( + EltWiseGradCPUKernel, float); +MS_REG_CPU_KERNEL_T( ReLU6Grad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - EltWiseGradCPUKernel); -MS_REG_CPU_KERNEL( + EltWiseGradCPUKernel, float); +MS_REG_CPU_KERNEL_T( AbsGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - EltWiseGradCPUKernel); -MS_REG_CPU_KERNEL( + EltWiseGradCPUKernel, float); +MS_REG_CPU_KERNEL_T( SigmoidGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - EltWiseGradCPUKernel); -MS_REG_CPU_KERNEL( + EltWiseGradCPUKernel, float); +MS_REG_CPU_KERNEL_T( SqrtGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - EltWiseGradCPUKernel); -MS_REG_CPU_KERNEL( + EltWiseGradCPUKernel, float); +MS_REG_CPU_KERNEL_T( TanhGrad, 
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - EltWiseGradCPUKernel); -MS_REG_CPU_KERNEL(GeLUGrad, - KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - EltWiseGradCPUKernel); -MS_REG_CPU_KERNEL( + EltWiseGradCPUKernel, float); +MS_REG_CPU_KERNEL_T(GeLUGrad, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + EltWiseGradCPUKernel, float); +MS_REG_CPU_KERNEL_T( AsinGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - EltWiseGradCPUKernel); -MS_REG_CPU_KERNEL( + EltWiseGradCPUKernel, float); +MS_REG_CPU_KERNEL_T( ACosGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - EltWiseGradCPUKernel); -MS_REG_CPU_KERNEL( + EltWiseGradCPUKernel, float); +MS_REG_CPU_KERNEL_T( AtanGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - EltWiseGradCPUKernel); -MS_REG_CPU_KERNEL( + EltWiseGradCPUKernel, float); +MS_REG_CPU_KERNEL_T( AsinhGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - EltWiseGradCPUKernel); -MS_REG_CPU_KERNEL( + EltWiseGradCPUKernel, float); +MS_REG_CPU_KERNEL_T( AcoshGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - EltWiseGradCPUKernel); + EltWiseGradCPUKernel, float); } // namespace kernel } // namespace mindspore