From 9d9fb1895b53dd56dc4f0d27cfb1f83b15d4fe3b Mon Sep 17 00:00:00 2001 From: zhujingxuan Date: Mon, 16 May 2022 15:10:09 +0800 Subject: [PATCH] Add MatchKernelHelper to remove duplicate code --- mindspore/ccsrc/kernel/common_utils.h | 31 +++++++++++++ .../device/cpu/kernel/lerp_cpu_kernel.cc | 45 ++++++++----------- .../device/cpu/kernel/lerp_cpu_kernel.h | 18 ++++---- 3 files changed, 59 insertions(+), 35 deletions(-) diff --git a/mindspore/ccsrc/kernel/common_utils.h b/mindspore/ccsrc/kernel/common_utils.h index 07f73941453..70e8b237e30 100644 --- a/mindspore/ccsrc/kernel/common_utils.h +++ b/mindspore/ccsrc/kernel/common_utils.h @@ -311,6 +311,37 @@ inline std::map GetKernelDepends(const CNodePtr &cn return std::map(); } +template +class MatchKernelHelper { + public: + using KernelRunFunc = std::function &, const std::vector &, + const std::vector &)>; + virtual const std::vector> &GetFuncList() const = 0; + + protected: + std::vector GetOpSupport() { + auto &func_list = static_cast(this)->GetFuncList(); + std::vector support_list; + (void)std::transform(func_list.begin(), func_list.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; + } + bool MatchKernelFunc(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) { + auto kernel_name = base_operator->name(); + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto &func_list = static_cast(this)->GetFuncList(); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "The kernel '" << kernel_name << "' does not support this kernel data type: " << kernel_attr; + return false; + } + kernel_func_ = func_list[index].second; + return true; + } + KernelRunFunc kernel_func_; +}; + #define CHECK_KERNEL_INPUTS_NUM(actual_inputs_num, expect_inputs_num, kernel_name) \ do { \ if ((actual_inputs_num) != (expect_inputs_num)) { \ diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/lerp_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/lerp_cpu_kernel.cc index 517da41dee6..fc25de3093f 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/lerp_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/lerp_cpu_kernel.cc @@ -30,13 +30,9 @@ bool LerpCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vec MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid."; return false; } - auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); - auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); - if (!is_match) { - MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr; - return false; + if (auto ret = MatchKernelFunc(base_operator, inputs, outputs); !ret) { + return ret; } - kernel_func_ = func_list_[index].second; return true; } @@ -72,7 +68,7 @@ void LerpCpuKernelMod::ResetResource() noexcept { } template -bool LerpCpuKernelMod::LaunchKernel(const std::vector &inputs, +bool LerpCpuKernelMod::LaunchKernel(const std::vector &inputs, const std::vector &, const std::vector &outputs) { if (start_shape_ == end_shape_ && start_shape_ == weight_shape_) { auto *input_start = reinterpret_cast(inputs.at(kIndex0)->addr); @@ -111,25 +107,22 @@ bool LerpCpuKernelMod::LaunchKernel(const std::vector &input return true; } -std::vector> LerpCpuKernelMod::func_list_ = { - {KernelAttr() - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddInputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16), - &LerpCpuKernelMod::LaunchKernel}, - {KernelAttr() - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddInputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32), - &LerpCpuKernelMod::LaunchKernel}}; - -std::vector LerpCpuKernelMod::GetOpSupport() { - std::vector support_list; - (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), - [](const std::pair &pair) { return pair.first; }); - return support_list; +const std::vector> &LerpCpuKernelMod::GetFuncList() const { + static const std::vector> func_list = { + {KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + &LerpCpuKernelMod::LaunchKernel}, + {KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + &LerpCpuKernelMod::LaunchKernel}, + }; + return func_list; } MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Lerp, LerpCpuKernelMod); } // namespace kernel diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/lerp_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/lerp_cpu_kernel.h index c7b66e80919..aca7806acec 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/lerp_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/lerp_cpu_kernel.h @@ -24,17 +24,18 @@ #include #include "plugin/device/cpu/kernel/cpu_kernel.h" #include "plugin/factory/ms_factory.h" +#include "kernel/common_utils.h" namespace mindspore { namespace kernel { -class LerpCpuKernelMod : public NativeCpuKernelMod { +class LerpCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper { public: LerpCpuKernelMod() = default; ~LerpCpuKernelMod() override = default; - bool Launch(const std::vector &inputs, const std::vector &, + bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) override { - return kernel_func_(this, inputs, outputs); + return kernel_func_(this, inputs, workspace, outputs); } bool Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, @@ -45,21 +46,20 @@ class LerpCpuKernelMod : public NativeCpuKernelMod { void ResetResource() noexcept; + const std::vector> &GetFuncList() const override; + protected: - std::vector GetOpSupport() override; + std::vector GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); } private: template - bool LaunchKernel(const std::vector &inputs, const std::vector &outputs); - using LerpFunc = std::function &, - const std::vector &)>; + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); size_t output_size_{1}; - LerpFunc kernel_func_; std::vector start_shape_; std::vector end_shape_; std::vector weight_shape_; std::vector output_shape_; - static std::vector> func_list_; }; } // namespace kernel } // namespace mindspore