!34556 Refactor the EltWiseCpuKernel with the new interfaces.

Merge pull request !34556 from liqiliang/eltwise-cpu
i-robot 2022-05-18 07:12:12 +00:00 committed by Gitee
commit 55f491ff46
6 changed files with 153 additions and 79 deletions
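Before the file-by-file diff, here is a minimal standalone C++ sketch of the pattern the refactor adopts: one-time configuration in Init(), shape-dependent setup in Resize(), and dispatch through a name-keyed function table in Launch(). All class and member names in this sketch (ToyEltWiseKernel, kernel_table, Apply, ...) are invented placeholders, not MindSpore's actual API; the real interfaces are the Init/Resize/Launch overrides shown in eltwise_cpu_kernel.h and eltwise_cpu_kernel.cc below.

// Standalone sketch only: toy stand-ins for the real KernelMod interfaces.
#include <cmath>
#include <cstddef>
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

class ToyEltWiseKernel {
 public:
  explicit ToyEltWiseKernel(std::string name) : kernel_name_(std::move(name)) {}

  // Init: resolve the launcher for this kernel name once, up front
  // (mirrors how Init() picks kernel_func_ out of kernel_attr_map_).
  bool Init() {
    auto iter = kernel_table().find(kernel_name_);
    if (iter == kernel_table().end()) {
      std::cerr << "unsupported kernel: " << kernel_name_ << "\n";
      return false;
    }
    launch_func_ = iter->second;
    return true;
  }

  // Resize: recompute shape-dependent state whenever the input shape changes
  // (mirrors Resize() filling src_shape_, input_element_num_ and is_null_input_).
  int Resize(const std::vector<size_t> &shape) {
    element_num_ = 1;
    for (size_t dim : shape) element_num_ *= dim;
    is_null_input_ = (element_num_ == 0);
    return 0;  // 0 plays the role of KRET_OK
  }

  // Launch: skip empty inputs, otherwise call the functor chosen in Init().
  bool Launch(const std::vector<float> &in, std::vector<float> *out) {
    if (is_null_input_) return true;
    return launch_func_(*this, in, out);
  }

 private:
  using LaunchFunc = std::function<bool(const ToyEltWiseKernel &,
                                        const std::vector<float> &, std::vector<float> *)>;

  // Elementwise apply helper shared by every table entry.
  bool Apply(float (*op)(float), const std::vector<float> &in, std::vector<float> *out) const {
    out->resize(element_num_);
    for (size_t i = 0; i < element_num_; ++i) (*out)[i] = op(in[i]);
    return true;
  }

  // Name-keyed dispatch table, the toy analogue of kernel_attr_map_.
  static const std::map<std::string, LaunchFunc> &kernel_table() {
    static const std::map<std::string, LaunchFunc> table = {
        {"Sqrt",
         [](const ToyEltWiseKernel &k, const std::vector<float> &in, std::vector<float> *out) {
           return k.Apply([](float x) { return std::sqrt(x); }, in, out);
         }},
        {"ReLU",
         [](const ToyEltWiseKernel &k, const std::vector<float> &in, std::vector<float> *out) {
           return k.Apply([](float x) { return x > 0.0f ? x : 0.0f; }, in, out);
         }}};
    return table;
  }

  std::string kernel_name_;
  LaunchFunc launch_func_;
  size_t element_num_{0};
  bool is_null_input_{false};
};

int main() {
  ToyEltWiseKernel sqrt_kernel("Sqrt");
  if (!sqrt_kernel.Init()) return 1;
  sqrt_kernel.Resize({2, 2});
  std::vector<float> in = {1.0f, 4.0f, 9.0f, 16.0f};
  std::vector<float> out;
  sqrt_kernel.Launch(in, &out);
  for (float v : out) std::cout << v << ' ';  // prints: 1 2 3 4
  std::cout << '\n';
  return 0;
}

In the diff itself the same split shows up as: Sqrt is removed from the MKLDNN-backed path in arithmetic_self_cpu_kernel.cc and re-registered against the refactored EltWiseCpuKernelMod, which now derives from MKLCpuKernelMod instead of DeprecatedMKLCpuKernelMod.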

View File

@@ -15,7 +15,6 @@
*/
#include "plugin/device/cpu/kernel/arithmetic_self_cpu_kernel.h"
#include <algorithm>
#include <cmath>
#include <complex>
@@ -24,9 +23,7 @@
#include <thread>
#include <unordered_map>
#include <utility>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "plugin/device/cpu/kernel/mkldnn/eltwise_cpu_kernel.h"
namespace mindspore {
namespace kernel {
@@ -620,26 +617,6 @@ void ArithmeticSelfCpuKernelFunc::LaunchKernelComplex(const std::vector<AddressP
func_pair->second(this, input, output, lens);
}
// MKLDNN Sqrt
class SqrtMKLKernelFunc : public CpuKernelFunc, private EltWiseCpuKernelMod {
public:
SqrtMKLKernelFunc() : EltWiseCpuKernelMod(kSqrt) {}
~SqrtMKLKernelFunc() override = default;
void InitFunc(const CNodePtr &kernel_node) override {
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
if (kernel_name != kSqrt) {
MS_LOG(EXCEPTION) << "Must be " << kSqrt << ", but got " << kernel_name;
}
EltWiseCpuKernelMod::InitKernel(kernel_node);
}
bool RunFunc(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
return EltWiseCpuKernelMod::Launch(inputs, workspace, outputs);
}
};
std::shared_ptr<CpuKernelFunc> CreateArithSelfFunc() { return std::make_shared<ArithmeticSelfCpuKernelFunc>(); }
using ArithFuncCreator = std::function<std::shared_ptr<CpuKernelFunc>()>;
static std::map<std::string, std::vector<std::pair<KernelAttr, ArithFuncCreator>>> arith_kernel_attr_list_map = {
@@ -753,10 +730,6 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithFuncCreator>
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
{kSqrt,
{{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
[]() { return std::make_shared<SqrtMKLKernelFunc>(); }}}},
{kErf,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
@@ -865,8 +838,6 @@ MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Atanh,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kAtanh); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Abs,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kAbs); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Sqrt,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kSqrt); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Erf,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kErf); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Erfc,

View File

@@ -15,10 +15,11 @@
*/
#include "plugin/device/cpu/kernel/mkldnn/eltwise_cpu_kernel.h"
#include <string>
#include <functional>
#include <unordered_map>
#include <memory>
#include <algorithm>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace kernel {
@@ -34,7 +35,7 @@ struct DescParam {
} // namespace
dnnl::eltwise_forward::desc EltWiseCpuKernelMod::GetForwardEltwiseDesc(const dnnl::memory::desc src_desc) {
static const std::unordered_map<std::string, DescParam> eltWiseOpDescMap{
static const std::unordered_map<std::string, DescParam> eltwise_op_desc_map{
{prim::kPrimRelu->name(), DescParam{dnnl::algorithm::eltwise_relu}},
{prim::kPrimRelu6->name(), DescParam{dnnl::algorithm::eltwise_clip, 0.0f, 6.0f}},
{prim::kPrimAbs->name(), DescParam{dnnl::algorithm::eltwise_abs}},
@@ -47,33 +48,88 @@ dnnl::eltwise_forward::desc EltWiseCpuKernelMod::GetForwardEltwiseDesc(const dnn
{prim::kPrimSoftplus->name(), DescParam{dnnl::algorithm::eltwise_soft_relu}},
{prim::kPrimMish->name(), DescParam{dnnl::algorithm::eltwise_mish}},
};
const auto desc_pair = eltWiseOpDescMap.find(kernel_name_);
if (desc_pair == eltWiseOpDescMap.end()) {
MS_LOG(EXCEPTION) << "EltWiseCpuKernelMod does not support " << kernel_name_;
const auto desc_pair = eltwise_op_desc_map.find(kernel_name_);
if (desc_pair == eltwise_op_desc_map.end()) {
MS_LOG(EXCEPTION) << "For 'EltWise Op', it does not support " << kernel_name_;
}
auto desc = CreateDesc<dnnl::eltwise_forward::desc>(dnnl_forward_, desc_pair->second.algorithm, src_desc,
desc_pair->second.alpha, desc_pair->second.beta);
return desc;
}
void EltWiseCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
if (src_shape.empty()) {
(void)src_shape.insert(src_shape.begin(), 1);
bool EltWiseCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
auto iter = kernel_attr_map_.find(kernel_name_);
if (iter == kernel_attr_map_.end()) {
MS_LOG(ERROR) << "For 'EltWise Op', the kernel name must be in " << kernel::Map2Str(kernel_attr_map_)
<< ", but got " << kernel_name_;
}
dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);
if (inputs.empty() || outputs.empty()) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid.";
return false;
}
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
return false;
}
kernel_func_ = iter->second[index].second;
return true;
}
int EltWiseCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
src_shape_.clear();
auto input_shape = inputs.at(kIndex0)->GetShapeVector();
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(src_shape_), LongToSize);
input_element_num_ = std::accumulate(src_shape_.begin(), src_shape_.end(), 1, std::multiplies<size_t>());
is_null_input_ = (input_element_num_ == 0);
if (is_null_input_) {
return KRET_OK;
}
TypeId input_type_id = inputs.at(kIndex0)->GetDtype();
auto dnnl_type_id = GetDnnlDataType(input_type_id);
if (dnnl_type_id == dnnl::memory::data_type::undef) {
MS_LOG(ERROR) << "For '" << kernel_name_
<< "', Resize failed, dnnl do not support data type:" << TypeIdToString(input_type_id);
return KRET_RESIZE_FAILED;
}
if (src_shape_.empty()) {
(void)src_shape_.insert(src_shape_.begin(), 1);
}
dnnl::memory::desc src_desc = GetExactMemDesc(src_shape_, dnnl_type_id);
auto desc = GetForwardEltwiseDesc(src_desc);
auto prim_desc = CreateDesc<dnnl::eltwise_forward::primitive_desc>(desc, engine_);
primitive_ = CreatePrimitive<dnnl::eltwise_forward>(prim_desc);
AddArgument(DNNL_ARG_SRC, src_desc);
AddArgument(DNNL_ARG_DST, src_desc);
return KRET_OK;
}
bool EltWiseCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
std::vector<KernelAttr> EltWiseCpuKernelMod::GetOpSupport() {
auto iter = kernel_attr_map_.find(kernel_name_);
if (iter == kernel_attr_map_.end()) {
MS_LOG(ERROR) << "For 'EltWise Op', the kernel name must be in " << kernel::Map2Str(kernel_attr_map_)
<< ", but got " << kernel_name_;
return std::vector<KernelAttr>{};
}
std::vector<KernelAttr> support_list;
(void)std::transform(iter->second.begin(), iter->second.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, EltWiseFunc> &pair) { return pair.first; });
return support_list;
}
template <typename T>
bool EltWiseCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr);
@@ -82,6 +138,39 @@ bool EltWiseCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
return true;
}
std::map<std::string, std::vector<std::pair<KernelAttr, EltWiseCpuKernelMod::EltWiseFunc>>>
EltWiseCpuKernelMod::kernel_attr_map_ = {
{kElu,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&EltWiseCpuKernelMod::LaunchKernel<float>}}},
{kReLU,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&EltWiseCpuKernelMod::LaunchKernel<float>}}},
{kReLU6,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&EltWiseCpuKernelMod::LaunchKernel<float>}}},
{kExp,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&EltWiseCpuKernelMod::LaunchKernel<float>}}},
{kLog,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&EltWiseCpuKernelMod::LaunchKernel<float>}}},
{kSigmoid,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&EltWiseCpuKernelMod::LaunchKernel<float>}}},
{kTanh,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&EltWiseCpuKernelMod::LaunchKernel<float>}}},
{kSoftplus,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&EltWiseCpuKernelMod::LaunchKernel<float>}}},
{prim::kPrimMish->name(),
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&EltWiseCpuKernelMod::LaunchKernel<float>}}},
{prim::kPrimSqrt->name(),
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&EltWiseCpuKernelMod::LaunchKernel<float>}}}};
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Elu, []() { return std::make_shared<EltWiseCpuKernelMod>(kElu); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ReLU,
[]() { return std::make_shared<EltWiseCpuKernelMod>(kReLU); });
@@ -97,5 +186,7 @@ MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Softplus,
[]() { return std::make_shared<EltWiseCpuKernelMod>(kSoftplus); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Mish,
[]() { return std::make_shared<EltWiseCpuKernelMod>(prim::kPrimMish->name()); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Sqrt,
[]() { return std::make_shared<EltWiseCpuKernelMod>(prim::kPrimSqrt->name()); });
} // namespace kernel
} // namespace mindspore

View File

@@ -21,6 +21,7 @@
#include <vector>
#include <map>
#include <string>
#include <utility>
#include "plugin/device/cpu/kernel/mkldnn/mkl_cpu_kernel.h"
namespace mindspore {
@@ -34,43 +35,42 @@ constexpr auto kSigmoid = "Sigmoid";
constexpr auto kTanh = "Tanh";
constexpr auto kSoftplus = "Softplus";
constexpr auto kUnKnown = "UnKnown";
class EltWiseCpuKernelMod : public DeprecatedMKLCpuKernelMod {
class EltWiseCpuKernelMod : public MKLCpuKernelMod {
public:
EltWiseCpuKernelMod() = default;
explicit EltWiseCpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
explicit EltWiseCpuKernelMod(const std::string &kernel_name) : kernel_name_(kernel_name) {}
~EltWiseCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
protected:
std::vector<KernelAttr> GetOpSupport() override {
static std::map<std::string, std::vector<KernelAttr>> support_list_map = {
{kElu, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
{kReLU, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
{kReLU6, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
{kExp, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
{kLog, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
{kSigmoid, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
{kTanh, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
{kSoftplus, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
{prim::kPrimMish->name(), {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
{prim::kPrimMish->name(), {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16)}}};
auto iter = support_list_map.find(kernel_type_);
if (iter == support_list_map.end()) {
MS_LOG(EXCEPTION) << "Does not support " << kernel_type_ << "!";
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override {
if (is_null_input_) {
return true;
}
return iter->second;
return kernel_func_(this, inputs, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
using EltWiseFunc = std::function<bool(EltWiseCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
static std::map<std::string, std::vector<std::pair<KernelAttr, EltWiseCpuKernelMod::EltWiseFunc>>> kernel_attr_map_;
EltWiseFunc kernel_func_;
dnnl::eltwise_forward::desc GetForwardEltwiseDesc(const dnnl::memory::desc src_desc);
dnnl::prop_kind dnnl_forward_{dnnl::prop_kind::forward_training};
std::string kernel_type_{kUnKnown};
std::string kernel_name_{kUnKnown};
std::vector<size_t> src_shape_{};
size_t input_element_num_{0};
bool is_null_input_{false};
};
} // namespace kernel
} // namespace mindspore

View File

@@ -69,20 +69,16 @@ bool LrnCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vect
int LrnCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
int ret = KRET_OK;
if ((ret = KernelMod::Resize(base_operator, inputs, outputs)) != KRET_OK) {
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
TypeId ms_type_id = inputs.at(kIndex0)->GetDtype();
std::map<TypeId, dnnl::memory::data_type> dnnl_data_type_map = {{kNumberTypeFloat16, dnnl::memory::data_type::f16},
{kNumberTypeFloat32, dnnl::memory::data_type::f32}};
if (dnnl_data_type_map.find(ms_type_id) == dnnl_data_type_map.end()) {
auto dnnl_type_id = GetDnnlDataType(ms_type_id);
if (dnnl_type_id == dnnl::memory::data_type::undef) {
MS_LOG(ERROR) << "For '" << kernel_name_
<< "' LrnCpuKernelMod::Resize failed, dnnl do not support data type:" << TypeIdToString(ms_type_id);
<< "', LrnCpuKernelMod::Resize failed, dnnl do not support data type:" << TypeIdToString(ms_type_id);
return KRET_RESIZE_FAILED;
}
auto dnnl_type_id = dnnl_data_type_map[ms_type_id];
std::vector<size_t> input_shape_;
auto input_shape = inputs.at(kIndex0)->GetShapeVector();
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(input_shape_), LongToSize);
@@ -96,7 +92,7 @@ int LrnCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vec
primitive_ = CreatePrimitive<dnnl::lrn_forward>(prim_desc);
AddArgument(DNNL_ARG_SRC, src_desc);
AddArgument(DNNL_ARG_DST, src_desc);
return ret;
return KRET_OK;
}
template <typename T>

View File

@@ -18,6 +18,7 @@
#include <vector>
#include <string>
#include <algorithm>
#include <map>
#include "utils/ms_utils.h"
#include "utils/profile.h"
@@ -464,6 +465,21 @@ size_t MKLCpuKernelMod::GetSize(const dnnl::memory::desc &desc) const {
return size;
}
dnnl::memory::data_type MKLCpuKernelMod::GetDnnlDataType(TypeId ms_type_id) {
static const std::map<TypeId, dnnl::memory::data_type> dnnl_data_type_map = {
{kNumberTypeFloat16, dnnl::memory::data_type::f16},
{kNumberTypeFloat32, dnnl::memory::data_type::f32},
{kNumberTypeInt32, dnnl::memory::data_type::s32},
{kNumberTypeInt8, dnnl::memory::data_type::s8},
{kNumberTypeUInt8, dnnl::memory::data_type::u8}};
auto iter = dnnl_data_type_map.find(ms_type_id);
if (iter == dnnl_data_type_map.end()) {
MS_LOG(ERROR) << "Dnnl do not support data type:" << TypeIdToString(ms_type_id);
return dnnl::memory::data_type::undef;
}
return iter->second;
}
void MKLCpuKernelMod::Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem) {
MS_LOG(DEBUG) << "begin to invoke constructor of dnnl::reorder";
auto desc = dnnl::reorder(*src_mem, *dst_mem);

View File

@@ -252,8 +252,8 @@ class MKLCpuKernelMod : public NativeCpuKernelMod {
return desc;
}
void Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem);
size_t GetSize(const dnnl::memory::desc &desc) const;
dnnl::memory::data_type GetDnnlDataType(TypeId ms_type_id);
void SetDataHandle(dnnl::memory mem, void *ptr);
void *GetDataHandle(const dnnl::memory &mem) const;
std::unordered_map<int, dnnl::memory> arguments_;