forked from mindspore-Ecosystem/mindspore
!34556 Refactor the EltWiseCpuKernel to use the new interfaces.
Merge pull request !34556 from liqiliang/eltwise-cpu
This commit is contained in:
commit
55f491ff46
|
@ -15,7 +15,6 @@
|
|||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/arithmetic_self_cpu_kernel.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
|
@ -24,9 +23,7 @@
|
|||
#include <thread>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
#include "plugin/device/cpu/kernel/mkldnn/eltwise_cpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -620,26 +617,6 @@ void ArithmeticSelfCpuKernelFunc::LaunchKernelComplex(const std::vector<AddressP
|
|||
func_pair->second(this, input, output, lens);
|
||||
}
|
||||
|
||||
// MKLDNN Sqrt
|
||||
class SqrtMKLKernelFunc : public CpuKernelFunc, private EltWiseCpuKernelMod {
|
||||
public:
|
||||
SqrtMKLKernelFunc() : EltWiseCpuKernelMod(kSqrt) {}
|
||||
~SqrtMKLKernelFunc() override = default;
|
||||
|
||||
void InitFunc(const CNodePtr &kernel_node) override {
|
||||
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
if (kernel_name != kSqrt) {
|
||||
MS_LOG(EXCEPTION) << "Must be " << kSqrt << ", but got " << kernel_name;
|
||||
}
|
||||
EltWiseCpuKernelMod::InitKernel(kernel_node);
|
||||
}
|
||||
|
||||
bool RunFunc(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return EltWiseCpuKernelMod::Launch(inputs, workspace, outputs);
|
||||
}
|
||||
};
|
||||
|
||||
// Factory helper: creates an ArithmeticSelfCpuKernelFunc and returns it as a shared CpuKernelFunc pointer.
std::shared_ptr<CpuKernelFunc> CreateArithSelfFunc() { return std::make_shared<ArithmeticSelfCpuKernelFunc>(); }
|
||||
using ArithFuncCreator = std::function<std::shared_ptr<CpuKernelFunc>()>;
|
||||
static std::map<std::string, std::vector<std::pair<KernelAttr, ArithFuncCreator>>> arith_kernel_attr_list_map = {
|
||||
|
@ -753,10 +730,6 @@ static std::map<std::string, std::vector<std::pair<KernelAttr, ArithFuncCreator>
|
|||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
|
||||
{kSqrt,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
[]() { return std::make_shared<SqrtMKLKernelFunc>(); }}}},
|
||||
{kErf,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
|
||||
|
@ -865,8 +838,6 @@ MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Atanh,
|
|||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kAtanh); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Abs,
|
||||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kAbs); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Sqrt,
|
||||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kSqrt); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Erf,
|
||||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kErf); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Erfc,
|
||||
|
|
|
@ -15,10 +15,11 @@
|
|||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/mkldnn/eltwise_cpu_kernel.h"
|
||||
#include <string>
|
||||
#include <functional>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -34,7 +35,7 @@ struct DescParam {
|
|||
} // namespace
|
||||
|
||||
dnnl::eltwise_forward::desc EltWiseCpuKernelMod::GetForwardEltwiseDesc(const dnnl::memory::desc src_desc) {
|
||||
static const std::unordered_map<std::string, DescParam> eltWiseOpDescMap{
|
||||
static const std::unordered_map<std::string, DescParam> eltwise_op_desc_map{
|
||||
{prim::kPrimRelu->name(), DescParam{dnnl::algorithm::eltwise_relu}},
|
||||
{prim::kPrimRelu6->name(), DescParam{dnnl::algorithm::eltwise_clip, 0.0f, 6.0f}},
|
||||
{prim::kPrimAbs->name(), DescParam{dnnl::algorithm::eltwise_abs}},
|
||||
|
@ -47,33 +48,88 @@ dnnl::eltwise_forward::desc EltWiseCpuKernelMod::GetForwardEltwiseDesc(const dnn
|
|||
{prim::kPrimSoftplus->name(), DescParam{dnnl::algorithm::eltwise_soft_relu}},
|
||||
{prim::kPrimMish->name(), DescParam{dnnl::algorithm::eltwise_mish}},
|
||||
};
|
||||
const auto desc_pair = eltWiseOpDescMap.find(kernel_name_);
|
||||
if (desc_pair == eltWiseOpDescMap.end()) {
|
||||
MS_LOG(EXCEPTION) << "EltWiseCpuKernelMod does not support " << kernel_name_;
|
||||
const auto desc_pair = eltwise_op_desc_map.find(kernel_name_);
|
||||
if (desc_pair == eltwise_op_desc_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "For 'EltWise Op', it does not support " << kernel_name_;
|
||||
}
|
||||
auto desc = CreateDesc<dnnl::eltwise_forward::desc>(dnnl_forward_, desc_pair->second.algorithm, src_desc,
|
||||
desc_pair->second.alpha, desc_pair->second.beta);
|
||||
return desc;
|
||||
}
|
||||
|
||||
void EltWiseCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
if (src_shape.empty()) {
|
||||
(void)src_shape.insert(src_shape.begin(), 1);
|
||||
bool EltWiseCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
kernel_name_ = base_operator->name();
|
||||
auto iter = kernel_attr_map_.find(kernel_name_);
|
||||
if (iter == kernel_attr_map_.end()) {
|
||||
MS_LOG(ERROR) << "For 'EltWise Op', the kernel name must be in " << kernel::Map2Str(kernel_attr_map_)
|
||||
<< ", but got " << kernel_name_;
|
||||
}
|
||||
dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);
|
||||
if (inputs.empty() || outputs.empty()) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid.";
|
||||
return false;
|
||||
}
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
kernel_func_ = iter->second[index].second;
|
||||
return true;
|
||||
}
|
||||
|
||||
int EltWiseCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
src_shape_.clear();
|
||||
auto input_shape = inputs.at(kIndex0)->GetShapeVector();
|
||||
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(src_shape_), LongToSize);
|
||||
input_element_num_ = std::accumulate(src_shape_.begin(), src_shape_.end(), 1, std::multiplies<size_t>());
|
||||
is_null_input_ = (input_element_num_ == 0);
|
||||
if (is_null_input_) {
|
||||
return KRET_OK;
|
||||
}
|
||||
TypeId input_type_id = inputs.at(kIndex0)->GetDtype();
|
||||
auto dnnl_type_id = GetDnnlDataType(input_type_id);
|
||||
if (dnnl_type_id == dnnl::memory::data_type::undef) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_
|
||||
<< "', Resize failed, dnnl do not support data type:" << TypeIdToString(input_type_id);
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
if (src_shape_.empty()) {
|
||||
(void)src_shape_.insert(src_shape_.begin(), 1);
|
||||
}
|
||||
dnnl::memory::desc src_desc = GetExactMemDesc(src_shape_, dnnl_type_id);
|
||||
|
||||
auto desc = GetForwardEltwiseDesc(src_desc);
|
||||
auto prim_desc = CreateDesc<dnnl::eltwise_forward::primitive_desc>(desc, engine_);
|
||||
primitive_ = CreatePrimitive<dnnl::eltwise_forward>(prim_desc);
|
||||
AddArgument(DNNL_ARG_SRC, src_desc);
|
||||
AddArgument(DNNL_ARG_DST, src_desc);
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
bool EltWiseCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
std::vector<KernelAttr> EltWiseCpuKernelMod::GetOpSupport() {
|
||||
auto iter = kernel_attr_map_.find(kernel_name_);
|
||||
if (iter == kernel_attr_map_.end()) {
|
||||
MS_LOG(ERROR) << "For 'EltWise Op', the kernel name must be in " << kernel::Map2Str(kernel_attr_map_)
|
||||
<< ", but got " << kernel_name_;
|
||||
return std::vector<KernelAttr>{};
|
||||
}
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(iter->second.begin(), iter->second.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, EltWiseFunc> &pair) { return pair.first; });
|
||||
|
||||
return support_list;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool EltWiseCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
|
||||
SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr);
|
||||
|
@ -82,6 +138,39 @@ bool EltWiseCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|||
return true;
|
||||
}
|
||||
|
||||
// Dispatch table: operator name -> list of (supported KernelAttr, launch function).
// Every operator listed here currently has a single float32 -> float32 entry that is
// served by LaunchKernel<float>.
std::map<std::string, std::vector<std::pair<KernelAttr, EltWiseCpuKernelMod::EltWiseFunc>>>
  EltWiseCpuKernelMod::kernel_attr_map_ = {
    {kElu,
     {{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
       &EltWiseCpuKernelMod::LaunchKernel<float>}}},
    {kReLU,
     {{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
       &EltWiseCpuKernelMod::LaunchKernel<float>}}},
    {kReLU6,
     {{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
       &EltWiseCpuKernelMod::LaunchKernel<float>}}},
    {kExp,
     {{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
       &EltWiseCpuKernelMod::LaunchKernel<float>}}},
    {kLog,
     {{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
       &EltWiseCpuKernelMod::LaunchKernel<float>}}},
    {kSigmoid,
     {{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
       &EltWiseCpuKernelMod::LaunchKernel<float>}}},
    {kTanh,
     {{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
       &EltWiseCpuKernelMod::LaunchKernel<float>}}},
    {kSoftplus,
     {{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
       &EltWiseCpuKernelMod::LaunchKernel<float>}}},
    {prim::kPrimMish->name(),
     {{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
       &EltWiseCpuKernelMod::LaunchKernel<float>}}},
    {prim::kPrimSqrt->name(),
     {{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
       &EltWiseCpuKernelMod::LaunchKernel<float>}}}};
|
||||
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Elu, []() { return std::make_shared<EltWiseCpuKernelMod>(kElu); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ReLU,
|
||||
[]() { return std::make_shared<EltWiseCpuKernelMod>(kReLU); });
|
||||
|
@ -97,5 +186,7 @@ MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Softplus,
|
|||
[]() { return std::make_shared<EltWiseCpuKernelMod>(kSoftplus); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Mish,
|
||||
[]() { return std::make_shared<EltWiseCpuKernelMod>(prim::kPrimMish->name()); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Sqrt,
|
||||
[]() { return std::make_shared<EltWiseCpuKernelMod>(prim::kPrimSqrt->name()); });
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <vector>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/kernel/mkldnn/mkl_cpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -34,43 +35,42 @@ constexpr auto kSigmoid = "Sigmoid";
|
|||
constexpr auto kTanh = "Tanh";
|
||||
constexpr auto kSoftplus = "Softplus";
|
||||
constexpr auto kUnKnown = "UnKnown";
|
||||
class EltWiseCpuKernelMod : public DeprecatedMKLCpuKernelMod {
|
||||
class EltWiseCpuKernelMod : public MKLCpuKernelMod {
|
||||
public:
|
||||
EltWiseCpuKernelMod() = default;
|
||||
explicit EltWiseCpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
|
||||
explicit EltWiseCpuKernelMod(const std::string &kernel_name) : kernel_name_(kernel_name) {}
|
||||
~EltWiseCpuKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override {
|
||||
static std::map<std::string, std::vector<KernelAttr>> support_list_map = {
|
||||
{kElu, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
|
||||
{kReLU, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
|
||||
{kReLU6, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
|
||||
{kExp, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
|
||||
{kLog, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
|
||||
{kSigmoid, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
|
||||
{kTanh, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
|
||||
{kSoftplus, {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
|
||||
{prim::kPrimMish->name(), {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)}},
|
||||
{prim::kPrimMish->name(), {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16)}}};
|
||||
auto iter = support_list_map.find(kernel_type_);
|
||||
if (iter == support_list_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Does not support " << kernel_type_ << "!";
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
return iter->second;
|
||||
return kernel_func_(this, inputs, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
using EltWiseFunc = std::function<bool(EltWiseCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &)>;
|
||||
static std::map<std::string, std::vector<std::pair<KernelAttr, EltWiseCpuKernelMod::EltWiseFunc>>> kernel_attr_map_;
|
||||
EltWiseFunc kernel_func_;
|
||||
dnnl::eltwise_forward::desc GetForwardEltwiseDesc(const dnnl::memory::desc src_desc);
|
||||
|
||||
dnnl::prop_kind dnnl_forward_{dnnl::prop_kind::forward_training};
|
||||
|
||||
std::string kernel_type_{kUnKnown};
|
||||
std::string kernel_name_{kUnKnown};
|
||||
std::vector<size_t> src_shape_{};
|
||||
size_t input_element_num_{0};
|
||||
bool is_null_input_{false};
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -69,20 +69,16 @@ bool LrnCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vect
|
|||
int LrnCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
int ret = KRET_OK;
|
||||
if ((ret = KernelMod::Resize(base_operator, inputs, outputs)) != KRET_OK) {
|
||||
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
TypeId ms_type_id = inputs.at(kIndex0)->GetDtype();
|
||||
std::map<TypeId, dnnl::memory::data_type> dnnl_data_type_map = {{kNumberTypeFloat16, dnnl::memory::data_type::f16},
|
||||
{kNumberTypeFloat32, dnnl::memory::data_type::f32}};
|
||||
|
||||
if (dnnl_data_type_map.find(ms_type_id) == dnnl_data_type_map.end()) {
|
||||
auto dnnl_type_id = GetDnnlDataType(ms_type_id);
|
||||
if (dnnl_type_id == dnnl::memory::data_type::undef) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_
|
||||
<< "' LrnCpuKernelMod::Resize failed, dnnl do not support data type:" << TypeIdToString(ms_type_id);
|
||||
<< "', LrnCpuKernelMod::Resize failed, dnnl do not support data type:" << TypeIdToString(ms_type_id);
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
auto dnnl_type_id = dnnl_data_type_map[ms_type_id];
|
||||
std::vector<size_t> input_shape_;
|
||||
auto input_shape = inputs.at(kIndex0)->GetShapeVector();
|
||||
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(input_shape_), LongToSize);
|
||||
|
@ -96,7 +92,7 @@ int LrnCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vec
|
|||
primitive_ = CreatePrimitive<dnnl::lrn_forward>(prim_desc);
|
||||
AddArgument(DNNL_ARG_SRC, src_desc);
|
||||
AddArgument(DNNL_ARG_DST, src_desc);
|
||||
return ret;
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include "utils/ms_utils.h"
|
||||
#include "utils/profile.h"
|
||||
|
||||
|
@ -464,6 +465,21 @@ size_t MKLCpuKernelMod::GetSize(const dnnl::memory::desc &desc) const {
|
|||
return size;
|
||||
}
|
||||
|
||||
// Translates a MindSpore TypeId into the matching dnnl memory data type.
// Handled types: float16, float32, int32, int8 and uint8. Any other type id is logged as an
// error and reported to the caller as dnnl::memory::data_type::undef.
dnnl::memory::data_type MKLCpuKernelMod::GetDnnlDataType(TypeId ms_type_id) {
  switch (ms_type_id) {
    case kNumberTypeFloat16:
      return dnnl::memory::data_type::f16;
    case kNumberTypeFloat32:
      return dnnl::memory::data_type::f32;
    case kNumberTypeInt32:
      return dnnl::memory::data_type::s32;
    case kNumberTypeInt8:
      return dnnl::memory::data_type::s8;
    case kNumberTypeUInt8:
      return dnnl::memory::data_type::u8;
    default:
      break;
  }
  MS_LOG(ERROR) << "Dnnl do not support data type:" << TypeIdToString(ms_type_id);
  return dnnl::memory::data_type::undef;
}
|
||||
|
||||
void MKLCpuKernelMod::Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem) {
|
||||
MS_LOG(DEBUG) << "begin to invoke constructor of dnnl::reorder";
|
||||
auto desc = dnnl::reorder(*src_mem, *dst_mem);
|
||||
|
|
|
@ -252,8 +252,8 @@ class MKLCpuKernelMod : public NativeCpuKernelMod {
|
|||
return desc;
|
||||
}
|
||||
void Reorder(dnnl::memory *src_mem, dnnl::memory *dst_mem);
|
||||
|
||||
size_t GetSize(const dnnl::memory::desc &desc) const;
|
||||
dnnl::memory::data_type GetDnnlDataType(TypeId ms_type_id);
|
||||
void SetDataHandle(dnnl::memory mem, void *ptr);
|
||||
void *GetDataHandle(const dnnl::memory &mem) const;
|
||||
std::unordered_map<int, dnnl::memory> arguments_;
|
||||
|
|
Loading…
Reference in New Issue