forked from mindspore-Ecosystem/mindspore
Add MatchKernelHelper to remove duplicate code
This commit is contained in:
parent
2550449223
commit
9d9fb1895b
|
@ -311,6 +311,37 @@ inline std::map<uint32_t, tensor::TensorPtr> GetKernelDepends(const CNodePtr &cn
|
||||||
return std::map<uint32_t, tensor::TensorPtr>();
|
return std::map<uint32_t, tensor::TensorPtr>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Derived>
|
||||||
|
class MatchKernelHelper {
|
||||||
|
public:
|
||||||
|
using KernelRunFunc = std::function<bool(Derived *, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
|
||||||
|
const std::vector<AddressPtr> &)>;
|
||||||
|
virtual const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const = 0;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::vector<KernelAttr> GetOpSupport() {
|
||||||
|
auto &func_list = static_cast<Derived *>(this)->GetFuncList();
|
||||||
|
std::vector<KernelAttr> support_list;
|
||||||
|
(void)std::transform(func_list.begin(), func_list.end(), std::back_inserter(support_list),
|
||||||
|
[](const std::pair<KernelAttr, KernelRunFunc> &pair) { return pair.first; });
|
||||||
|
return support_list;
|
||||||
|
}
|
||||||
|
bool MatchKernelFunc(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
const std::vector<KernelTensorPtr> &outputs) {
|
||||||
|
auto kernel_name = base_operator->name();
|
||||||
|
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||||
|
auto &func_list = static_cast<Derived *>(this)->GetFuncList();
|
||||||
|
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||||
|
if (!is_match) {
|
||||||
|
MS_LOG(ERROR) << "The kernel '" << kernel_name << "' does not support this kernel data type: " << kernel_attr;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
kernel_func_ = func_list[index].second;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
KernelRunFunc kernel_func_;
|
||||||
|
};
|
||||||
|
|
||||||
#define CHECK_KERNEL_INPUTS_NUM(actual_inputs_num, expect_inputs_num, kernel_name) \
|
#define CHECK_KERNEL_INPUTS_NUM(actual_inputs_num, expect_inputs_num, kernel_name) \
|
||||||
do { \
|
do { \
|
||||||
if ((actual_inputs_num) != (expect_inputs_num)) { \
|
if ((actual_inputs_num) != (expect_inputs_num)) { \
|
||||||
|
|
|
@ -30,13 +30,9 @@ bool LerpCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vec
|
||||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid.";
|
MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
if (auto ret = MatchKernelFunc(base_operator, inputs, outputs); !ret) {
|
||||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
return ret;
|
||||||
if (!is_match) {
|
|
||||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr;
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
kernel_func_ = func_list_[index].second;
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -72,7 +68,7 @@ void LerpCpuKernelMod::ResetResource() noexcept {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
bool LerpCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
bool LerpCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||||
const std::vector<kernel::AddressPtr> &outputs) {
|
const std::vector<kernel::AddressPtr> &outputs) {
|
||||||
if (start_shape_ == end_shape_ && start_shape_ == weight_shape_) {
|
if (start_shape_ == end_shape_ && start_shape_ == weight_shape_) {
|
||||||
auto *input_start = reinterpret_cast<T *>(inputs.at(kIndex0)->addr);
|
auto *input_start = reinterpret_cast<T *>(inputs.at(kIndex0)->addr);
|
||||||
|
@ -111,7 +107,8 @@ bool LerpCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &input
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::pair<KernelAttr, LerpCpuKernelMod::LerpFunc>> LerpCpuKernelMod::func_list_ = {
|
const std::vector<std::pair<KernelAttr, LerpCpuKernelMod::KernelRunFunc>> &LerpCpuKernelMod::GetFuncList() const {
|
||||||
|
static const std::vector<std::pair<KernelAttr, LerpCpuKernelMod::KernelRunFunc>> func_list = {
|
||||||
{KernelAttr()
|
{KernelAttr()
|
||||||
.AddInputAttr(kNumberTypeFloat16)
|
.AddInputAttr(kNumberTypeFloat16)
|
||||||
.AddInputAttr(kNumberTypeFloat16)
|
.AddInputAttr(kNumberTypeFloat16)
|
||||||
|
@ -123,13 +120,9 @@ std::vector<std::pair<KernelAttr, LerpCpuKernelMod::LerpFunc>> LerpCpuKernelMod:
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
.AddInputAttr(kNumberTypeFloat32)
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
.AddOutputAttr(kNumberTypeFloat32),
|
.AddOutputAttr(kNumberTypeFloat32),
|
||||||
&LerpCpuKernelMod::LaunchKernel<float>}};
|
&LerpCpuKernelMod::LaunchKernel<float>},
|
||||||
|
};
|
||||||
std::vector<KernelAttr> LerpCpuKernelMod::GetOpSupport() {
|
return func_list;
|
||||||
std::vector<KernelAttr> support_list;
|
|
||||||
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
|
||||||
[](const std::pair<KernelAttr, LerpFunc> &pair) { return pair.first; });
|
|
||||||
return support_list;
|
|
||||||
}
|
}
|
||||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Lerp, LerpCpuKernelMod);
|
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Lerp, LerpCpuKernelMod);
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
|
|
|
@ -24,17 +24,18 @@
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||||
#include "plugin/factory/ms_factory.h"
|
#include "plugin/factory/ms_factory.h"
|
||||||
|
#include "kernel/common_utils.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace kernel {
|
namespace kernel {
|
||||||
class LerpCpuKernelMod : public NativeCpuKernelMod {
|
class LerpCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<LerpCpuKernelMod> {
|
||||||
public:
|
public:
|
||||||
LerpCpuKernelMod() = default;
|
LerpCpuKernelMod() = default;
|
||||||
~LerpCpuKernelMod() override = default;
|
~LerpCpuKernelMod() override = default;
|
||||||
|
|
||||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
const std::vector<AddressPtr> &outputs) override {
|
const std::vector<AddressPtr> &outputs) override {
|
||||||
return kernel_func_(this, inputs, outputs);
|
return kernel_func_(this, inputs, workspace, outputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||||
|
@ -45,21 +46,20 @@ class LerpCpuKernelMod : public NativeCpuKernelMod {
|
||||||
|
|
||||||
void ResetResource() noexcept;
|
void ResetResource() noexcept;
|
||||||
|
|
||||||
|
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::vector<KernelAttr> GetOpSupport() override;
|
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
template <typename T>
|
template <typename T>
|
||||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
using LerpFunc = std::function<bool(LerpCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
const std::vector<kernel::AddressPtr> &outputs);
|
||||||
const std::vector<kernel::AddressPtr> &)>;
|
|
||||||
size_t output_size_{1};
|
size_t output_size_{1};
|
||||||
LerpFunc kernel_func_;
|
|
||||||
std::vector<size_t> start_shape_;
|
std::vector<size_t> start_shape_;
|
||||||
std::vector<size_t> end_shape_;
|
std::vector<size_t> end_shape_;
|
||||||
std::vector<size_t> weight_shape_;
|
std::vector<size_t> weight_shape_;
|
||||||
std::vector<size_t> output_shape_;
|
std::vector<size_t> output_shape_;
|
||||||
static std::vector<std::pair<KernelAttr, LerpFunc>> func_list_;
|
|
||||||
};
|
};
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
Loading…
Reference in New Issue