forked from mindspore-Ecosystem/mindspore
!34478 Add MatchKernelHelper to remove code for solving kernel dtype issue
Merge pull request !34478 from zhujingxuan/CRTP
This commit is contained in:
commit
49317d8022
|
@ -311,6 +311,37 @@ inline std::map<uint32_t, tensor::TensorPtr> GetKernelDepends(const CNodePtr &cn
|
|||
return std::map<uint32_t, tensor::TensorPtr>();
|
||||
}
|
||||
|
||||
template <typename Derived>
|
||||
class MatchKernelHelper {
|
||||
public:
|
||||
using KernelRunFunc = std::function<bool(Derived *, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &)>;
|
||||
virtual const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const = 0;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() {
|
||||
auto &func_list = static_cast<Derived *>(this)->GetFuncList();
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(func_list.begin(), func_list.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, KernelRunFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
bool MatchKernelFunc(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
auto kernel_name = base_operator->name();
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto &func_list = static_cast<Derived *>(this)->GetFuncList();
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(ERROR) << "The kernel '" << kernel_name << "' does not support this kernel data type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
kernel_func_ = func_list[index].second;
|
||||
return true;
|
||||
}
|
||||
KernelRunFunc kernel_func_;
|
||||
};
|
||||
|
||||
#define CHECK_KERNEL_INPUTS_NUM(actual_inputs_num, expect_inputs_num, kernel_name) \
|
||||
do { \
|
||||
if ((actual_inputs_num) != (expect_inputs_num)) { \
|
||||
|
|
|
@ -30,13 +30,9 @@ bool LerpCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vec
|
|||
MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid.";
|
||||
return false;
|
||||
}
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr;
|
||||
return false;
|
||||
if (auto ret = MatchKernelFunc(base_operator, inputs, outputs); !ret) {
|
||||
return ret;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -72,7 +68,7 @@ void LerpCpuKernelMod::ResetResource() noexcept {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
bool LerpCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
bool LerpCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (start_shape_ == end_shape_ && start_shape_ == weight_shape_) {
|
||||
auto *input_start = reinterpret_cast<T *>(inputs.at(kIndex0)->addr);
|
||||
|
@ -111,25 +107,22 @@ bool LerpCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &input
|
|||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, LerpCpuKernelMod::LerpFunc>> LerpCpuKernelMod::func_list_ = {
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
&LerpCpuKernelMod::LaunchKernel<float16>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&LerpCpuKernelMod::LaunchKernel<float>}};
|
||||
|
||||
std::vector<KernelAttr> LerpCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, LerpFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
const std::vector<std::pair<KernelAttr, LerpCpuKernelMod::KernelRunFunc>> &LerpCpuKernelMod::GetFuncList() const {
|
||||
static const std::vector<std::pair<KernelAttr, LerpCpuKernelMod::KernelRunFunc>> func_list = {
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
&LerpCpuKernelMod::LaunchKernel<float16>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&LerpCpuKernelMod::LaunchKernel<float>},
|
||||
};
|
||||
return func_list;
|
||||
}
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Lerp, LerpCpuKernelMod);
|
||||
} // namespace kernel
|
||||
|
|
|
@ -24,17 +24,18 @@
|
|||
#include <functional>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
#include "kernel/common_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class LerpCpuKernelMod : public NativeCpuKernelMod {
|
||||
class LerpCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<LerpCpuKernelMod> {
|
||||
public:
|
||||
LerpCpuKernelMod() = default;
|
||||
~LerpCpuKernelMod() override = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, outputs);
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
|
@ -45,21 +46,20 @@ class LerpCpuKernelMod : public NativeCpuKernelMod {
|
|||
|
||||
void ResetResource() noexcept;
|
||||
|
||||
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
std::vector<KernelAttr> GetOpSupport() override { return MatchKernelHelper::GetOpSupport(); }
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
using LerpFunc = std::function<bool(LerpCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &)>;
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
size_t output_size_{1};
|
||||
LerpFunc kernel_func_;
|
||||
std::vector<size_t> start_shape_;
|
||||
std::vector<size_t> end_shape_;
|
||||
std::vector<size_t> weight_shape_;
|
||||
std::vector<size_t> output_shape_;
|
||||
static std::vector<std::pair<KernelAttr, LerpFunc>> func_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue