forked from mindspore-Ecosystem/mindspore
refactor cpu factory
This commit is contained in:
parent
8d2f098802
commit
db2931516e
|
@ -24,9 +24,10 @@
|
|||
#include "include/common/utils/context/graph_kernel_flags.h"
|
||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
#include "runtime/device/kernel_runtime.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/akg/akg_cpu_kernel_build.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/device/cpu/hal/device/kernel_select_cpu.h"
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include "backend/common/optimizer/pass_manager.h"
|
||||
|
@ -375,11 +376,12 @@ void CPUSession::BuildKernel(const KernelGraph *kernel_graph) {
|
|||
continue;
|
||||
}
|
||||
std::shared_ptr<kernel::NativeCpuKernelMod> cpu_kernel_mod =
|
||||
kernel::NativeCpuKernelModFactory::GetInstance().Create(kernel_name, kernel_node);
|
||||
kernel::Factory<kernel::NativeCpuKernelMod>::Instance().Create(kernel_name);
|
||||
if (cpu_kernel_mod == nullptr) {
|
||||
KernelNotSupportException(kernel_node);
|
||||
}
|
||||
try {
|
||||
cpu_kernel_mod->SetCpuRefMapToKernelInfo(kernel_node);
|
||||
cpu_kernel_mod->Init(kernel_node);
|
||||
} catch (std::exception &e) {
|
||||
MS_LOG(EXCEPTION) << e.what() << trace::DumpSourceLines(kernel_node);
|
||||
|
|
|
@ -30,21 +30,20 @@ namespace server {
|
|||
namespace kernel {
|
||||
using mindspore::kernel::SGDCpuKernelMod;
|
||||
template <typename T>
|
||||
|
||||
class SGDKernelMod : public SGDCpuKernelMod<T>, public OptimizerKernelMod {
|
||||
class SGDKernelMod : public SGDCpuKernelMod, public OptimizerKernelMod {
|
||||
public:
|
||||
SGDKernelMod() = default;
|
||||
~SGDKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &cnode) override {
|
||||
SGDCpuKernelMod<T>::InitKernel(cnode);
|
||||
SGDCpuKernelMod::InitKernel(cnode);
|
||||
InitServerKernelInputOutputSize(cnode);
|
||||
GenerateReuseKernelNodeInfo();
|
||||
}
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return SGDCpuKernelMod<T>::Launch(inputs, workspace, outputs);
|
||||
return SGDCpuKernelMod::Launch(inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
void GenerateReuseKernelNodeInfo() override {
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include <vector>
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <algorithm>
|
||||
|
||||
#include "utils/log_adapter.h"
|
||||
#include "ir/dtype/type.h"
|
||||
|
|
|
@ -543,16 +543,6 @@ bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &
|
|||
return true;
|
||||
}
|
||||
|
||||
int Sign(float x) {
|
||||
if (x > 0) {
|
||||
return 1;
|
||||
}
|
||||
if (x < 0) {
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
|
||||
const std::vector<AnfNodePtr> &input_list,
|
||||
const std::vector<AnfNodePtr> &output_list) {
|
||||
|
@ -1011,5 +1001,122 @@ size_t UnitSizeInBytes(const mindspore::TypeId &t) {
|
|||
|
||||
return bytes;
|
||||
}
|
||||
|
||||
KernelAttr &KernelAttr::AddInputAttr(const TypeId &ms_type, const std::string &format) {
|
||||
input_type_.emplace_back(ms_type, format);
|
||||
return *this;
|
||||
}
|
||||
|
||||
KernelAttr &KernelAttr::AddOutputAttr(const TypeId &ms_type, const std::string &format) {
|
||||
output_type_.emplace_back(ms_type, format);
|
||||
return *this;
|
||||
}
|
||||
|
||||
KernelAttr &KernelAttr::AddAllSameAttr(const bool &all_same) {
|
||||
all_same_ = all_same;
|
||||
return *this;
|
||||
}
|
||||
|
||||
KernelAttr &KernelAttr::AddOutInRef(size_t output_index, size_t input_index) {
|
||||
out_in_ref_map_[output_index] = input_index;
|
||||
return *this;
|
||||
}
|
||||
|
||||
void KernelAttr::SetInputAttrList(const std::vector<DataType> &addr_list) {
|
||||
input_type_.assign(addr_list.begin(), addr_list.end());
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &os, KernelAttr kernel_attr) {
|
||||
std::stringstream ss;
|
||||
ss << "[Kernel Attr] all same: " << kernel_attr.GetAllSame();
|
||||
size_t input_num = kernel_attr.GetInputSize();
|
||||
if (input_num > 0) {
|
||||
ss << ", input(";
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
ss << TypeIdLabel(kernel_attr.GetInputAttr(i).first);
|
||||
if (i != input_num - 1) {
|
||||
ss << ",";
|
||||
}
|
||||
}
|
||||
ss << ") ";
|
||||
}
|
||||
size_t output_num = kernel_attr.GetOutputSize();
|
||||
if (output_num > 0) {
|
||||
ss << ", output(";
|
||||
for (size_t i = 0; i < output_num; ++i) {
|
||||
ss << TypeIdLabel(kernel_attr.GetOutputAttr(i).first);
|
||||
if (i != output_num - 1) {
|
||||
ss << ",";
|
||||
}
|
||||
}
|
||||
ss << ").";
|
||||
}
|
||||
|
||||
return os << ss.str();
|
||||
}
|
||||
|
||||
std::pair<bool, size_t> MatchKernelAttr(const KernelAttr &kernel_attr,
|
||||
const std::vector<KernelAttr> &kernel_attr_list) {
|
||||
// kernel_attr should not be all same. If so, then return false.
|
||||
if (kernel_attr.GetAllSame()) {
|
||||
return std::make_pair(false, 0);
|
||||
}
|
||||
|
||||
auto input_num = kernel_attr.GetInputSize();
|
||||
auto output_num = kernel_attr.GetOutputSize();
|
||||
|
||||
for (size_t index = 0; index < kernel_attr_list.size(); ++index) {
|
||||
const auto &cur_kernel_attr = kernel_attr_list[index];
|
||||
auto cur_input_num = cur_kernel_attr.GetInputSize();
|
||||
auto cur_output_num = cur_kernel_attr.GetOutputSize();
|
||||
if (!cur_kernel_attr.GetAllSame() && (input_num != cur_input_num || output_num != cur_output_num)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool mis_match = false;
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto dtype = cur_kernel_attr.GetInputAttr(cur_kernel_attr.GetAllSame() ? 0 : i).first;
|
||||
auto format = cur_kernel_attr.GetInputAttr(cur_kernel_attr.GetAllSame() ? 0 : i).second;
|
||||
if (kernel_attr.GetInputAttr(i).first != dtype || kernel_attr.GetInputAttr(i).second != format) {
|
||||
mis_match = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (mis_match) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < output_num; ++i) {
|
||||
auto dtype = cur_kernel_attr.GetOutputAttr(cur_kernel_attr.GetAllSame() ? 0 : i).first;
|
||||
auto format = cur_kernel_attr.GetOutputAttr(cur_kernel_attr.GetAllSame() ? 0 : i).second;
|
||||
if (kernel_attr.GetOutputAttr(i).first != dtype || kernel_attr.GetOutputAttr(i).second != format) {
|
||||
mis_match = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!mis_match) {
|
||||
return std::make_pair(true, index);
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_pair(false, 0);
|
||||
}
|
||||
|
||||
KernelAttr GetKernelAttrFromBuildInfo(const KernelBuildInfoPtr &build_info) {
|
||||
MS_EXCEPTION_IF_NULL(build_info);
|
||||
KernelAttr kernel_attr;
|
||||
for (size_t i = 0; i < build_info->GetInputNum(); i++) {
|
||||
kernel_attr.AddInputAttr(build_info->GetInputDeviceType(i), build_info->GetInputFormat(i));
|
||||
}
|
||||
for (size_t j = 0; j < build_info->GetOutputNum(); j++) {
|
||||
kernel_attr.AddOutputAttr(build_info->GetOutputDeviceType(j), build_info->GetOutputFormat(j));
|
||||
}
|
||||
return kernel_attr;
|
||||
}
|
||||
|
||||
KernelAttr GetKernelAttrFromNode(const AnfNodePtr &kernel_node) {
|
||||
auto build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
|
||||
return GetKernelAttrFromBuildInfo(build_info);
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_COMMON_UTILS_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_COMMON_UTILS_H_
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_COMMON_UTILS_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_COMMON_UTILS_H_
|
||||
|
||||
#include <dirent.h>
|
||||
#include <memory>
|
||||
|
@ -23,6 +23,7 @@
|
|||
#include <unordered_set>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
|
@ -145,7 +146,6 @@ void SaveJsonInfo(const std::string &json_name, const std::string &info, const s
|
|||
std::string GetProcessor(const AnfNodePtr &anf_node);
|
||||
Processor GetProcessor(const string &processor);
|
||||
bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b);
|
||||
int Sign(float x);
|
||||
std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
|
||||
const std::vector<AnfNodePtr> &input_list,
|
||||
const std::vector<AnfNodePtr> &output_list);
|
||||
|
@ -235,6 +235,41 @@ size_t GetCopySize(const std::vector<int64_t> &dim_offset, const std::vector<int
|
|||
const std::vector<int64_t> &stop);
|
||||
size_t UnitSizeInBytes(const mindspore::TypeId &t);
|
||||
|
||||
class KernelAttr {
|
||||
public:
|
||||
using DataType = std::pair<TypeId, std::string>;
|
||||
KernelAttr() = default;
|
||||
~KernelAttr() = default;
|
||||
|
||||
KernelAttr &AddInputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT);
|
||||
KernelAttr &AddOutputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT);
|
||||
KernelAttr &AddAllSameAttr(const bool &all_same);
|
||||
KernelAttr &AddOutInRef(size_t output_index, size_t input_index);
|
||||
|
||||
const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; }
|
||||
const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; }
|
||||
const bool &GetAllSame() const { return all_same_; }
|
||||
|
||||
size_t GetInputSize() const { return input_type_.size(); }
|
||||
size_t GetOutputSize() const { return output_type_.size(); }
|
||||
const OutputInputRefMap &GetOutInRefMap() const { return out_in_ref_map_; }
|
||||
|
||||
void SetInputAttrList(const std::vector<DataType> &addr_list);
|
||||
|
||||
private:
|
||||
std::vector<DataType> input_type_;
|
||||
std::vector<DataType> output_type_;
|
||||
bool all_same_{false};
|
||||
|
||||
// The map between kernel's output and input ref relationship.
|
||||
OutputInputRefMap out_in_ref_map_;
|
||||
};
|
||||
std::ostream &operator<<(std::ostream &os, KernelAttr kernel_attr);
|
||||
|
||||
std::pair<bool, size_t> MatchKernelAttr(const KernelAttr &kernel_attr, const std::vector<KernelAttr> &attr_list);
|
||||
KernelAttr GetKernelAttrFromBuildInfo(const KernelBuildInfoPtr &build_info);
|
||||
KernelAttr GetKernelAttrFromNode(const AnfNodePtr &kernel_node);
|
||||
|
||||
#define CHECK_KERNEL_INPUTS_NUM(actual_inputs_num, expect_inputs_num, kernel_name) \
|
||||
do { \
|
||||
if ((actual_inputs_num) != (expect_inputs_num)) { \
|
||||
|
@ -261,4 +296,4 @@ size_t UnitSizeInBytes(const mindspore::TypeId &t);
|
|||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_COMMON_UTILS_H_
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_COMMON_UTILS_H_
|
||||
|
|
|
@ -19,7 +19,8 @@
|
|||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include "kernel/common_utils.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "kernel/kernel_build_info.h"
|
||||
#include "kernel/oplib/opinfo.h"
|
||||
#include "kernel/oplib/oplib.h"
|
||||
|
@ -83,7 +84,7 @@ void GetInputFormat(const CNodePtr &kernel_node, std::vector<std::string> *input
|
|||
}
|
||||
}
|
||||
|
||||
void GetOutputFormatsAndDtypes(const CNodePtr &kernel_node, const KernelAttr &kernel_attr,
|
||||
void GetOutputFormatsAndDtypes(const CNodePtr &kernel_node, const kernel::KernelAttr &kernel_attr,
|
||||
std::vector<std::string> *output_formats, std::vector<TypeId> *output_types) {
|
||||
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
|
@ -107,7 +108,7 @@ bool InputDtypeMatch(TypeId InputAttr, TypeId input_type, bool strict) {
|
|||
return false;
|
||||
}
|
||||
|
||||
int GetOutputDtypeMatchedNum(const KernelAttr &kernel_attr, const std::vector<TypeId> &output_types) {
|
||||
int GetOutputDtypeMatchedNum(const kernel::KernelAttr &kernel_attr, const std::vector<TypeId> &output_types) {
|
||||
if (kernel_attr.GetOutputSize() != output_types.size()) {
|
||||
MS_LOG(DEBUG) << "required output num:" << kernel_attr.GetInputSize()
|
||||
<< ", actual output num:" << output_types.size();
|
||||
|
@ -126,7 +127,7 @@ int GetOutputDtypeMatchedNum(const KernelAttr &kernel_attr, const std::vector<Ty
|
|||
return data_type_matched_num;
|
||||
}
|
||||
|
||||
int GetInputDtypeFormatMatchedNum(const KernelAttr &kernel_attr, const std::vector<TypeId> &input_types,
|
||||
int GetInputDtypeFormatMatchedNum(const kernel::KernelAttr &kernel_attr, const std::vector<TypeId> &input_types,
|
||||
const std::vector<size_t> &input_not_cnode_indexes, bool strict) {
|
||||
if (kernel_attr.GetInputSize() != input_types.size()) {
|
||||
MS_LOG(DEBUG) << "required input num:" << kernel_attr.GetInputSize() << ", actual input num:" << input_types.size();
|
||||
|
@ -145,7 +146,7 @@ int GetInputDtypeFormatMatchedNum(const KernelAttr &kernel_attr, const std::vect
|
|||
return data_type_matched_num;
|
||||
}
|
||||
|
||||
void ExpandKernelAttr(const CNodePtr &kernel_node, KernelAttr *kernel_attr) {
|
||||
void ExpandKernelAttr(const CNodePtr &kernel_node, kernel::KernelAttr *kernel_attr) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_attr);
|
||||
size_t attr_num = kernel_attr->GetInputSize();
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
|
@ -169,7 +170,7 @@ void ExpandKernelAttr(const CNodePtr &kernel_node, KernelAttr *kernel_attr) {
|
|||
TypeId output_dtype = kernel_attr->GetOutputAttr(0).first;
|
||||
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
for (size_t i = 1; i < output_num; ++i) {
|
||||
kernel_attr->AddOutputAttr(output_dtype);
|
||||
(void)kernel_attr->AddOutputAttr(output_dtype);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -218,7 +219,7 @@ void KernelNotSupportException(const AnfNodePtr &kernel_node, const std::vector<
|
|||
MS_EXCEPTION(TypeError) << operator_info.str() << trace::DumpSourceLines(kernel_node);
|
||||
}
|
||||
|
||||
void UpdateDynamicKernelBuildInfoAndAttrs(const CNodePtr &kernel_node) {
|
||||
void UpdateDynamicKernelBuildInfo(const CNodePtr &kernel_node) {
|
||||
const std::string &op_name = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
MS_LOG(INFO) << "Operator name: " << op_name;
|
||||
// Set kernel build info
|
||||
|
@ -232,35 +233,9 @@ void UpdateDynamicKernelBuildInfoAndAttrs(const CNodePtr &kernel_node) {
|
|||
std::vector<std::string> output_formats;
|
||||
GetOutputFormat(kernel_node, &output_formats);
|
||||
SetKernelBuildInfo(input_formats, input_types, output_formats, output_types, kernel_node.get());
|
||||
|
||||
// Set kernel attrs
|
||||
KernelAttr attr;
|
||||
for (size_t i = 0; i < input_types.size(); i++) {
|
||||
attr.AddInputAttr(input_types[i]);
|
||||
}
|
||||
for (size_t j = 0; j < output_types.size(); j++) {
|
||||
attr.AddInputAttr(output_types[j]);
|
||||
}
|
||||
std::vector<KernelAttr> kernel_attrs = kernel::NativeCpuKernelModFactory::GetInstance().GetSupportedKernelAttrList(
|
||||
common::AnfAlgo::GetCNodeName(kernel_node));
|
||||
kernel_attrs.emplace_back(attr);
|
||||
kernel::NativeCpuKernelModFactory::GetInstance().UpdateKernelAttrs(op_name, kernel_attrs);
|
||||
return;
|
||||
}
|
||||
|
||||
void AddKernelAttr(const CNodePtr &kernel_node, const KernelAttr &kernel_attr) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
std::vector<KernelAttr> kernel_attrs = kernel::NativeCpuKernelModFactory::GetInstance().GetSupportedKernelAttrList(
|
||||
common::AnfAlgo::GetCNodeName(kernel_node));
|
||||
if (kernel_attrs.size() == 1 && kernel_attrs[0].GetInputSize() == 0 && kernel_attrs[0].GetOutputSize() == 0) {
|
||||
kernel_attrs.clear();
|
||||
}
|
||||
kernel_attrs.emplace_back(kernel_attr);
|
||||
kernel::NativeCpuKernelModFactory::GetInstance().UpdateKernelAttrs(common::AnfAlgo::GetCNodeName(kernel_node),
|
||||
kernel_attrs);
|
||||
}
|
||||
|
||||
void UpdateCustomKernelBuildInfoAndAttrs(const CNodePtr &kernel_node, bool is_akg_op) {
|
||||
void UpdateCustomKernelBuildInfo(const CNodePtr &kernel_node, bool is_akg_op) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
if (is_akg_op) {
|
||||
|
@ -285,29 +260,11 @@ void UpdateCustomKernelBuildInfoAndAttrs(const CNodePtr &kernel_node, bool is_ak
|
|||
builder->SetOutputsDeviceType(output_types);
|
||||
builder->SetOutputsFormat(output_formats);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get());
|
||||
if (!is_akg_op) {
|
||||
// Update kernel attrs
|
||||
KernelAttr attr;
|
||||
if (input_types.size() != input_formats.size()) {
|
||||
MS_LOG(EXCEPTION) << "input types size " << input_types.size() << " is not equal to input formats size "
|
||||
<< input_formats.size() << " in op: " << common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
}
|
||||
for (size_t i = 0; i < input_types.size(); ++i) {
|
||||
attr.AddInputAttr(input_types[i], input_formats[i]);
|
||||
}
|
||||
if (output_types.size() != output_formats.size()) {
|
||||
MS_LOG(EXCEPTION) << "output types size " << output_types.size() << " is not equal to output formats size "
|
||||
<< output_formats.size() << " in op: " << common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
}
|
||||
for (size_t i = 0; i < output_types.size(); ++i) {
|
||||
attr.AddOutputAttr(output_types[i], output_formats[i]);
|
||||
}
|
||||
AddKernelAttr(kernel_node, attr);
|
||||
}
|
||||
}
|
||||
|
||||
KernelAttr FillNoneInKernelAttr(const CNodePtr &kernel_node, const std::vector<TypeId> &input_types,
|
||||
const std::vector<TypeId> &output_types, const KernelAttr &kernel_attr) {
|
||||
kernel::KernelAttr FillNoneInKernelAttr(const CNodePtr &kernel_node, const std::vector<TypeId> &input_types,
|
||||
const std::vector<TypeId> &output_types,
|
||||
const kernel::KernelAttr &kernel_attr) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
// Only process Custom op
|
||||
if (!IsPrimitiveCNode(kernel_node, prim::kPrimCustom)) {
|
||||
|
@ -320,17 +277,14 @@ KernelAttr FillNoneInKernelAttr(const CNodePtr &kernel_node, const std::vector<T
|
|||
MS_LOG(DEBUG) << "required input num:" << kernel_attr.GetOutputSize() << ", actual input num:" << output_num;
|
||||
return kernel_attr;
|
||||
}
|
||||
KernelAttr result;
|
||||
bool has_none = false;
|
||||
kernel::KernelAttr result;
|
||||
// Fill inputs info.
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto type_format = kernel_attr.GetInputAttr(i);
|
||||
if (type_format.first == TypeId::kMetaTypeNone) {
|
||||
has_none = true;
|
||||
type_format.first = input_types[i];
|
||||
}
|
||||
if (type_format.second.empty()) {
|
||||
has_none = true;
|
||||
type_format.second = kOpFormat_DEFAULT;
|
||||
}
|
||||
result.AddInputAttr(type_format.first, type_format.second);
|
||||
|
@ -339,18 +293,13 @@ KernelAttr FillNoneInKernelAttr(const CNodePtr &kernel_node, const std::vector<T
|
|||
for (size_t i = 0; i < output_num; ++i) {
|
||||
auto type_format = kernel_attr.GetOutputAttr(i);
|
||||
if (type_format.first == TypeId::kMetaTypeNone) {
|
||||
has_none = true;
|
||||
type_format.first = output_types[i];
|
||||
}
|
||||
if (type_format.second.empty()) {
|
||||
has_none = true;
|
||||
type_format.second = kOpFormat_DEFAULT;
|
||||
}
|
||||
result.AddOutputAttr(type_format.first, type_format.second);
|
||||
}
|
||||
if (has_none) {
|
||||
AddKernelAttr(kernel_node, result);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
} // namespace
|
||||
|
@ -374,8 +323,8 @@ bool IsDynamicParamKernel(const std::string &op_name) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr,
|
||||
const std::vector<KernelAttr> &kernel_attrs, const std::vector<TypeId> &input_types,
|
||||
bool SelectKernel(const CNodePtr &kernel_node, kernel::KernelAttr *selected_kernel_attr,
|
||||
const std::vector<kernel::KernelAttr> &kernel_attrs, const std::vector<TypeId> &input_types,
|
||||
const std::vector<size_t> &input_not_cnode_indexes, const std::vector<TypeId> &output_types,
|
||||
std::pair<bool, bool> *matched, bool strict) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
|
@ -391,6 +340,7 @@ bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr,
|
|||
MS_LOG(DEBUG) << "Output num is not equal!";
|
||||
continue;
|
||||
}
|
||||
|
||||
auto new_kernel_attr = FillNoneInKernelAttr(kernel_node, input_types, output_types, kernel_attr);
|
||||
int input_dtype_matched_num =
|
||||
GetInputDtypeFormatMatchedNum(new_kernel_attr, input_types, input_not_cnode_indexes, strict);
|
||||
|
@ -412,17 +362,17 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
|
|||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
const std::string &op_name = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
if (IsPrimitiveCNode(kernel_node, prim::kPrimCustom)) {
|
||||
if (!kernel::NativeCpuKernelModFactory::GetInstance().SearchRegisteredOp(op_name)) {
|
||||
if (!kernel::Factory<kernel::NativeCpuKernelMod>::Instance().IsRegistered(op_name)) {
|
||||
auto tp = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, kAttrFuncType);
|
||||
if (tp == kCustomTypePyfunc) {
|
||||
kernel::NativeCpuKernelRegistrar(op_name, KernelAttr(),
|
||||
[]() { return std::make_shared<kernel::PyFuncCpuKernelMod>(); });
|
||||
kernel::Factory<kernel::NativeCpuKernelMod>::Instance().Register(
|
||||
op_name, []() { return std::make_shared<kernel::PyFuncCpuKernelMod>(); });
|
||||
} else if (tp == kCustomTypeAOT) {
|
||||
kernel::NativeCpuKernelRegistrar(op_name, KernelAttr(),
|
||||
[]() { return std::make_shared<kernel::CustomAOTCpuKernelMod>(); });
|
||||
kernel::Factory<kernel::NativeCpuKernelMod>::Instance().Register(
|
||||
op_name, []() { return std::make_shared<kernel::CustomAOTCpuKernelMod>(); });
|
||||
} else if (tp == kCustomTypeJULIA) {
|
||||
kernel::NativeCpuKernelRegistrar(op_name, KernelAttr(),
|
||||
[]() { return std::make_shared<kernel::CustomJULIACpuKernelMod>(); });
|
||||
kernel::Factory<kernel::NativeCpuKernelMod>::Instance().Register(
|
||||
op_name, []() { return std::make_shared<kernel::CustomJULIACpuKernelMod>(); });
|
||||
} else {
|
||||
MS_LOG(EXCEPTION)
|
||||
<< "Unsupported func type for Custom operator on CPU, it should be 'pyfunc' or 'aot' or 'julia', "
|
||||
|
@ -434,13 +384,15 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
|
|||
MS_LOG(WARNING) << "Not find operator information for Custom operator[" << op_name << "]. "
|
||||
<< "Infer operator information from inputs. For more details, "
|
||||
<< "please refer to 'mindspore.ops.Custom' at https://www.mindspore.cn.";
|
||||
return UpdateCustomKernelBuildInfoAndAttrs(kernel_node, false);
|
||||
UpdateCustomKernelBuildInfo(kernel_node, false);
|
||||
return;
|
||||
}
|
||||
} else if (IsDynamicParamKernel(op_name)) {
|
||||
// Select for dynamic kernel(both the number and data type are undetermined).
|
||||
return UpdateDynamicKernelBuildInfoAndAttrs(kernel_node);
|
||||
UpdateDynamicKernelBuildInfo(kernel_node);
|
||||
return;
|
||||
} else if (IsAKGSparseOP(kernel_node)) {
|
||||
return UpdateCustomKernelBuildInfoAndAttrs(kernel_node, true);
|
||||
return UpdateCustomKernelBuildInfo(kernel_node, true);
|
||||
}
|
||||
|
||||
std::vector<std::string> input_formats;
|
||||
|
@ -450,23 +402,11 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
|
|||
std::vector<TypeId> output_types;
|
||||
std::vector<TypeId> selected_output_types;
|
||||
MS_LOG(INFO) << "SetKernelInfo, CNode Name: " << op_name;
|
||||
auto kernel_attrs = kernel::NativeCpuKernelModFactory::GetInstance().GetSupportedKernelAttrList(
|
||||
common::AnfAlgo::GetCNodeName(kernel_node));
|
||||
if (kernel_attrs.empty() || (kernel_attrs[0].GetInputSize() == 0 && kernel_attrs[0].GetOutputSize() == 0)) {
|
||||
MS_LOG(DEBUG) << "Operator[" << op_name << "] will get ops attr info.";
|
||||
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kCPU);
|
||||
if (op_info_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Not find op[" << op_name << "] in cpu. For more details, "
|
||||
<< "please refer to the list of supported cpu operations at https://www.mindspore.cn.";
|
||||
}
|
||||
kernel_attrs.clear();
|
||||
kernel::NativeCpuKernelModFactory::GetInstance().SetKernelAttrs(op_info_ptr, &kernel_attrs);
|
||||
kernel::NativeCpuKernelModFactory::GetInstance().UpdateKernelAttrs(op_name, kernel_attrs);
|
||||
}
|
||||
GetInputDtypes(kernel_node, &input_types, &input_not_cnode_indexes);
|
||||
GetOutputDtypes(kernel_node, &output_types);
|
||||
KernelAttr selected_kernel_attr;
|
||||
kernel::KernelAttr selected_kernel_attr;
|
||||
std::pair<bool, bool> matched = std::make_pair(false, false);
|
||||
auto kernel_attrs = kernel::NativeCpuKernelMod::GetCpuSupportedList(op_name);
|
||||
if (!SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_types, input_not_cnode_indexes,
|
||||
output_types, &matched, true)) {
|
||||
if (op_name == "Cast") {
|
||||
|
|
|
@ -31,52 +31,6 @@ namespace cpu {
|
|||
using DataType = std::pair<TypeId, std::string>;
|
||||
|
||||
void SetKernelInfo(const CNodePtr &apply_kernel_ptr);
|
||||
// Indicate whether the kernel input/output number are variable.
|
||||
bool IsDynamicParamKernel(const std::string &op_name);
|
||||
|
||||
class KernelAttr {
|
||||
public:
|
||||
KernelAttr() : all_same_(0) {}
|
||||
~KernelAttr() = default;
|
||||
|
||||
KernelAttr &AddInputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) {
|
||||
input_type_.emplace_back(ms_type, format);
|
||||
return *this;
|
||||
}
|
||||
|
||||
KernelAttr &AddOutputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) {
|
||||
output_type_.emplace_back(ms_type, format);
|
||||
return *this;
|
||||
}
|
||||
|
||||
KernelAttr &SetAllSameAttr(bool all_same) {
|
||||
all_same_ = all_same;
|
||||
return *this;
|
||||
}
|
||||
|
||||
KernelAttr &AddOutInRef(size_t output_index, size_t input_index) {
|
||||
out_in_ref_map_[output_index] = input_index;
|
||||
return *this;
|
||||
}
|
||||
|
||||
const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; }
|
||||
const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; }
|
||||
bool GetAllSame() const { return all_same_; }
|
||||
void SetInputAttrList(const std::vector<DataType> &addr_list) {
|
||||
input_type_.assign(addr_list.begin(), addr_list.end());
|
||||
}
|
||||
|
||||
size_t GetInputSize() const { return input_type_.size(); }
|
||||
size_t GetOutputSize() const { return output_type_.size(); }
|
||||
const OutputInputRefMap &GetOutInRefMap() const { return out_in_ref_map_; }
|
||||
|
||||
private:
|
||||
std::vector<DataType> input_type_;
|
||||
std::vector<DataType> output_type_;
|
||||
// The map between kernel's output and input ref relationship.
|
||||
OutputInputRefMap out_in_ref_map_;
|
||||
bool all_same_;
|
||||
};
|
||||
} // namespace cpu
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,7 +19,8 @@
|
|||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_memory_manager.h"
|
||||
#include "plugin/device/cpu/kernel/akg/akg_cpu_kernel_build.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "kernel/kernel_build_info.h"
|
||||
#include "plugin/device/cpu/hal/device/kernel_select_cpu.h"
|
||||
#include "utils/trace_base.h"
|
||||
|
@ -220,11 +221,12 @@ void CPUDeviceContext::CreateKernel(const std::vector<CNodePtr> &nodes) const {
|
|||
}
|
||||
std::string kernel_name = common::AnfAlgo::GetCNodeName(node);
|
||||
std::shared_ptr<kernel::NativeCpuKernelMod> cpu_kernel =
|
||||
kernel::NativeCpuKernelModFactory::GetInstance().Create(kernel_name, node);
|
||||
kernel::Factory<kernel::NativeCpuKernelMod>::Instance().Create(kernel_name);
|
||||
if (!cpu_kernel) {
|
||||
MS_LOG(EXCEPTION) << "Build cpu operator[" << node->fullname_with_scope() << "] failed";
|
||||
}
|
||||
|
||||
cpu_kernel->SetCpuRefMapToKernelInfo(node);
|
||||
cpu_kernel->Init(node);
|
||||
AnfAlgo::SetKernelMod(cpu_kernel, node.get());
|
||||
}
|
||||
|
|
|
@ -161,5 +161,7 @@ bool AdamCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, con
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Adam, AdamCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include <memory>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -37,15 +37,11 @@ class AdamCpuKernelMod : public NativeCpuKernelMod {
|
|||
private:
|
||||
template <typename T>
|
||||
void LaunchAdam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
void LaunchAdamNnacl(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
bool use_nesterov_{false};
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
enum input_list_ { VAR, M, V, BETA1_POWER, BETA2_POWER, LR, BETA1, BETA2, EPSILON, GRAD };
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(Adam, KernelAttr(), AdamCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -149,5 +149,22 @@ bool AdamDeltaCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs
|
|||
LaunchAdamDelta<float>(delta, m, v, lr, beta1, beta2, epsilon, grad, lens);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> AdamDeltaCpuKernelMod::GetOpSupport() {
|
||||
static std::vector<KernelAttr> kernel_attr_list = {KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)};
|
||||
return kernel_attr_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, AdamNoUpdateParam, AdamDeltaCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -32,6 +32,9 @@ class AdamDeltaCpuKernelMod : public NativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
void CheckParams(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) const;
|
||||
|
||||
|
@ -42,20 +45,6 @@ class AdamDeltaCpuKernelMod : public NativeCpuKernelMod {
|
|||
size_t elem_num_{0};
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(AdamNoUpdateParam,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
AdamDeltaCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -144,5 +144,7 @@ bool AdamWeightDecayCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, AdamWeightDecay, AdamWeightDecayCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include <memory>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -41,8 +41,6 @@ class AdamWeightDecayCpuKernelMod : public NativeCpuKernelMod {
|
|||
TypeId dtype_{kTypeUnknown};
|
||||
enum input_list_ { VAR, M, V, LR, BETA1, BETA2, EPSILON, DECAY, GRAD };
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(AdamWeightDecay, KernelAttr(), AdamWeightDecayCpuKernelMod)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -45,5 +45,12 @@ bool AllGatherCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs
|
|||
auto input_data_num = inputs[0]->size / sizeof(float);
|
||||
return MPIAllGather(input_addr, output_addr, ranks_group_, input_data_num);
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> AllGatherCpuKernelMod::GetOpSupport() {
|
||||
static std::vector<KernelAttr> kernel_attr_list = {
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)};
|
||||
return kernel_attr_list;
|
||||
}
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, _HostAllGather, AllGatherCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include <memory>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -35,12 +35,12 @@ class AllGatherCpuKernelMod : public NativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
std::vector<int> ranks_group_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(_HostAllGather, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
AllGatherCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -116,5 +116,25 @@ void ApplyAdagradCpuKernelMod::LaunchApplyAdagrad(T *var, T *accum, const T *lr,
|
|||
var[i] -= lr[0] * gradient[i] * (one / sqrt(accum[i] + eps));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> ApplyAdagradCpuKernelMod::GetOpSupport() {
|
||||
static std::vector<KernelAttr> kernel_attr_list = {KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)};
|
||||
return kernel_attr_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, ApplyAdagrad, ApplyAdagradCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -35,6 +35,9 @@ class ApplyAdagradCpuKernelMod : public NativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
void CheckParam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) const;
|
||||
|
||||
|
@ -47,26 +50,6 @@ class ApplyAdagradCpuKernelMod : public NativeCpuKernelMod {
|
|||
bool update_slots_{true};
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(ApplyAdagrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
ApplyAdagradCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(ApplyAdagrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
ApplyAdagradCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif
|
||||
|
|
|
@ -57,5 +57,28 @@ bool ApplyMomentumCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &in
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> ApplyMomentumCpuKernelMod::GetOpSupport() {
|
||||
static std::vector<KernelAttr> kernel_attr_list = {KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutInRef(0, 0),
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutInRef(0, 0)};
|
||||
return kernel_attr_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, ApplyMomentum, ApplyMomentumCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include <memory>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -34,29 +34,10 @@ class ApplyMomentumCpuKernelMod : public NativeCpuKernelMod {
|
|||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(ApplyMomentum,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutInRef(0, 0),
|
||||
ApplyMomentumCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(ApplyMomentum,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutInRef(0, 0),
|
||||
ApplyMomentumCpuKernelMod);
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -57,37 +57,9 @@ bool check_validation(const std::vector<size_t> &shape, const size_t num_before_
|
|||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
void ArgmaxCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
size_t shape_len = shape_.size();
|
||||
if (shape_len == 0) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of 'input_x' should be at least 1, but got 0.";
|
||||
}
|
||||
int64_t axis = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
|
||||
axis += SizeToLong(shape_len);
|
||||
if (axis < 0) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' should be in range [-1, " << (shape_len - 1)
|
||||
<< "], but got " << axis;
|
||||
}
|
||||
axis = axis % SizeToLong(shape_len);
|
||||
num_before_axis_ = 1;
|
||||
num_after_axis_ = 1;
|
||||
for (size_t i = 0; i < shape_len; i++) {
|
||||
if (SizeToLong(i) < axis) {
|
||||
num_before_axis_ *= shape_[i];
|
||||
} else if (SizeToLong(i) > axis) {
|
||||
num_after_axis_ *= shape_[i];
|
||||
}
|
||||
}
|
||||
dim_axis_ = shape_[LongToSize(axis)];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool ArgmaxCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
bool ArgmaxCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (!check_validation<T>(shape_, num_before_axis_, num_after_axis_, inputs, outputs)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -112,5 +84,50 @@ bool ArgmaxCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void ArgmaxCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
size_t shape_len = shape_.size();
|
||||
if (shape_len == 0) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of 'input_x' should be at least 1, but got 0.";
|
||||
}
|
||||
int64_t axis = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
|
||||
axis += SizeToLong(shape_len);
|
||||
if (axis < 0) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' should be in range [-1, " << (shape_len - 1)
|
||||
<< "], but got " << axis;
|
||||
}
|
||||
axis = axis % SizeToLong(shape_len);
|
||||
num_before_axis_ = 1;
|
||||
num_after_axis_ = 1;
|
||||
for (size_t i = 0; i < shape_len; i++) {
|
||||
if (SizeToLong(i) < axis) {
|
||||
num_before_axis_ *= shape_[i];
|
||||
} else if (SizeToLong(i) > axis) {
|
||||
num_after_axis_ *= shape_[i];
|
||||
}
|
||||
}
|
||||
dim_axis_ = shape_[LongToSize(axis)];
|
||||
|
||||
auto build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
|
||||
if (build_info->GetInputNum() < 1) {
|
||||
MS_LOG(EXCEPTION) << "Argmax input size should not less than 1!";
|
||||
}
|
||||
auto input_type_id = build_info->GetInputDeviceType(0);
|
||||
switch (input_type_id) {
|
||||
case kNumberTypeFloat32:
|
||||
kernel_func_ = &ArgmaxCpuKernelMod::LaunchKernel<float>;
|
||||
break;
|
||||
case kNumberTypeFloat16:
|
||||
kernel_func_ = &ArgmaxCpuKernelMod::LaunchKernel<float16>;
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "Argmax kernel does not support " << TypeIdToString(input_type_id);
|
||||
}
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Argmax, ArgmaxCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,11 +21,10 @@
|
|||
#include <memory>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class ArgmaxCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
ArgmaxCpuKernelMod() = default;
|
||||
|
@ -34,17 +33,24 @@ class ArgmaxCpuKernelMod : public NativeCpuKernelMod {
|
|||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
using ArgmaxFunc =
|
||||
std::function<bool(ArgmaxCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
ArgmaxFunc kernel_func_;
|
||||
|
||||
std::vector<size_t> shape_;
|
||||
size_t num_before_axis_{0};
|
||||
size_t num_after_axis_{0};
|
||||
size_t dim_axis_{0};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(Argmax, KernelAttr(), ArgmaxCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(Argmax, KernelAttr(), ArgmaxCpuKernelMod, float16);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -61,34 +61,9 @@ bool check_validation(const std::vector<size_t> &shape, const size_t num_before_
|
|||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
void ArgMaxWithValueCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
size_t shape_len = shape_.size();
|
||||
int64_t axis = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
|
||||
axis += SizeToLong(shape_len);
|
||||
if (axis < 0) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' should be in range [-1, " << (shape_len - 1)
|
||||
<< "], but got " << axis;
|
||||
}
|
||||
axis = axis % SizeToLong(shape_len);
|
||||
num_before_axis_ = 1;
|
||||
num_after_axis_ = 1;
|
||||
for (size_t i = 0; i < shape_len; i++) {
|
||||
if (SizeToLong(i) < axis) {
|
||||
num_before_axis_ *= shape_[i];
|
||||
} else if (SizeToLong(i) > axis) {
|
||||
num_after_axis_ *= shape_[i];
|
||||
}
|
||||
}
|
||||
dim_axis_ = shape_[LongToSize(axis)];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool ArgMaxWithValueCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
bool ArgMaxWithValueCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (!check_validation<T>(shape_, num_before_axis_, num_after_axis_, inputs, outputs)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -116,5 +91,47 @@ bool ArgMaxWithValueCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void ArgMaxWithValueCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
size_t shape_len = shape_.size();
|
||||
int64_t axis = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
|
||||
axis += SizeToLong(shape_len);
|
||||
if (axis < 0) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' should be in range [-1, " << (shape_len - 1)
|
||||
<< "], but got " << axis;
|
||||
}
|
||||
axis = axis % SizeToLong(shape_len);
|
||||
num_before_axis_ = 1;
|
||||
num_after_axis_ = 1;
|
||||
for (size_t i = 0; i < shape_len; i++) {
|
||||
if (SizeToLong(i) < axis) {
|
||||
num_before_axis_ *= shape_[i];
|
||||
} else if (SizeToLong(i) > axis) {
|
||||
num_after_axis_ *= shape_[i];
|
||||
}
|
||||
}
|
||||
dim_axis_ = shape_[LongToSize(axis)];
|
||||
|
||||
auto build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
|
||||
if (build_info->GetInputNum() < 1) {
|
||||
MS_LOG(EXCEPTION) << "Argmax input size should not less than 1!";
|
||||
}
|
||||
auto input_type_id = build_info->GetInputDeviceType(0);
|
||||
switch (input_type_id) {
|
||||
case kNumberTypeFloat32:
|
||||
kernel_func_ = &ArgMaxWithValueCpuKernelMod::LaunchKernel<float>;
|
||||
break;
|
||||
case kNumberTypeFloat16:
|
||||
kernel_func_ = &ArgMaxWithValueCpuKernelMod::LaunchKernel<float16>;
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "Argmax kernel does not support " << TypeIdToString(input_type_id);
|
||||
}
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, ArgMaxWithValue, ArgMaxWithValueCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,11 +23,10 @@
|
|||
#include <algorithm>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class ArgMaxWithValueCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
ArgMaxWithValueCpuKernelMod() = default;
|
||||
|
@ -35,18 +34,25 @@ class ArgMaxWithValueCpuKernelMod : public NativeCpuKernelMod {
|
|||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
using ArgMaxWithValueFunc =
|
||||
std::function<bool(ArgMaxWithValueCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
ArgMaxWithValueFunc kernel_func_;
|
||||
|
||||
std::vector<size_t> shape_;
|
||||
size_t num_before_axis_{0};
|
||||
size_t num_after_axis_{0};
|
||||
size_t dim_axis_{0};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(ArgMaxWithValue, KernelAttr(), ArgMaxWithValueCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(ArgMaxWithValue, KernelAttr(), ArgMaxWithValueCpuKernelMod, float16);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -62,37 +62,9 @@ bool check_validation(const std::vector<size_t> &shape, const size_t num_before_
|
|||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
void ArgMinWithValueCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
size_t shape_len = shape_.size();
|
||||
if (shape_len == 0) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of 'input_x' should be at least 1, but got 0.";
|
||||
}
|
||||
int64_t axis = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
|
||||
axis += static_cast<int64_t>(shape_len);
|
||||
if (axis < 0) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' should be in range [-1, " << (shape_len - 1)
|
||||
<< "], but got " << axis;
|
||||
}
|
||||
axis = axis % static_cast<int64_t>(shape_len);
|
||||
num_before_axis_ = 1;
|
||||
num_after_axis_ = 1;
|
||||
for (size_t i = 0; i < shape_len; i++) {
|
||||
if (static_cast<int64_t>(i) < axis) {
|
||||
num_before_axis_ *= shape_[i];
|
||||
} else if (static_cast<int64_t>(i) > axis) {
|
||||
num_after_axis_ *= shape_[i];
|
||||
}
|
||||
}
|
||||
dim_axis_ = shape_[axis];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool ArgMinWithValueCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
bool ArgMinWithValueCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (!check_validation<T>(shape_, num_before_axis_, num_after_axis_, inputs, outputs)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -119,5 +91,50 @@ bool ArgMinWithValueCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void ArgMinWithValueCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
size_t shape_len = shape_.size();
|
||||
if (shape_len == 0) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of 'input_x' should be at least 1, but got 0.";
|
||||
}
|
||||
int64_t axis = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
|
||||
axis += static_cast<int64_t>(shape_len);
|
||||
if (axis < 0) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' should be in range [-1, " << (shape_len - 1)
|
||||
<< "], but got " << axis;
|
||||
}
|
||||
axis = axis % static_cast<int64_t>(shape_len);
|
||||
num_before_axis_ = 1;
|
||||
num_after_axis_ = 1;
|
||||
for (size_t i = 0; i < shape_len; i++) {
|
||||
if (static_cast<int64_t>(i) < axis) {
|
||||
num_before_axis_ *= shape_[i];
|
||||
} else if (static_cast<int64_t>(i) > axis) {
|
||||
num_after_axis_ *= shape_[i];
|
||||
}
|
||||
}
|
||||
dim_axis_ = shape_[axis];
|
||||
|
||||
auto build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
|
||||
if (build_info->GetInputNum() < 1) {
|
||||
MS_LOG(EXCEPTION) << "Argmax input size should not less than 1!";
|
||||
}
|
||||
auto input_type_id = build_info->GetInputDeviceType(0);
|
||||
switch (input_type_id) {
|
||||
case kNumberTypeFloat32:
|
||||
kernel_func_ = &ArgMinWithValueCpuKernelMod::LaunchKernel<float>;
|
||||
break;
|
||||
case kNumberTypeFloat16:
|
||||
kernel_func_ = &ArgMinWithValueCpuKernelMod::LaunchKernel<float16>;
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "Argmax kernel does not support " << TypeIdToString(input_type_id);
|
||||
}
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, ArgMinWithValue, ArgMinWithValueCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,11 +23,10 @@
|
|||
#include <algorithm>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class ArgMinWithValueCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
ArgMinWithValueCpuKernelMod() = default;
|
||||
|
@ -36,17 +35,24 @@ class ArgMinWithValueCpuKernelMod : public NativeCpuKernelMod {
|
|||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
using ArgMinWithValueFunc =
|
||||
std::function<bool(ArgMinWithValueCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
ArgMinWithValueFunc kernel_func_;
|
||||
|
||||
std::vector<size_t> shape_;
|
||||
size_t num_before_axis_{0};
|
||||
size_t num_after_axis_{0};
|
||||
size_t dim_axis_{0};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(ArgMinWithValue, KernelAttr(), ArgMinWithValueCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(ArgMinWithValue, KernelAttr(), ArgMinWithValueCpuKernelMod, float16);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -20,6 +20,10 @@
|
|||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
|
||||
#include "plugin/device/cpu/kernel/nnacl/fp32/power_fp32.h"
|
||||
#include "plugin/device/cpu/kernel/nnacl/fp32/sub_fp32.h"
|
||||
|
@ -29,11 +33,14 @@
|
|||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
constexpr size_t kInputsNum = 2;
|
||||
constexpr size_t kOutputsNum = 1;
|
||||
constexpr float kMaxSubSerialSize = 10000.0;
|
||||
constexpr float kMaxPowSerialSize = 700.0;
|
||||
|
||||
constexpr auto kSub = "Sub";
|
||||
constexpr auto kMul = "Mul";
|
||||
constexpr auto kRealDiv = "RealDiv";
|
||||
constexpr auto kAssignAdd = "AssignAdd";
|
||||
|
||||
template <typename T>
|
||||
void ElementRealDiv(const T *input1, const T *input2, T *out, size_t size, size_t delta_1, size_t delta_2) {
|
||||
size_t idx_1 = 0;
|
||||
|
@ -59,10 +66,123 @@ void ElementRealDiv(const T *input1, const T *input2, T *out, size_t size, size_
|
|||
out[i] = dividend / divisor;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticCpuKernelMod<T>::AssignAdd(T *input1, const T *input2, T *out) {
|
||||
class ArithmeticCpuTypeFunc : public CpuKernelFunc {
|
||||
public:
|
||||
~ArithmeticCpuTypeFunc() override = default;
|
||||
explicit ArithmeticCpuTypeFunc(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
input_shape1_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
input_shape2_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
if (output_shape_.size() == 0) {
|
||||
(void)output_shape_.insert(output_shape_.begin(), 1);
|
||||
}
|
||||
|
||||
output_size_ = 1;
|
||||
for (size_t i = 0; i < output_shape_.size(); ++i) {
|
||||
output_size_ *= output_shape_[i];
|
||||
}
|
||||
|
||||
op_para_.in_elements_num0_ = 1;
|
||||
for (size_t i = 0; i < input_shape1_.size(); ++i) {
|
||||
op_para_.in_elements_num0_ *= input_shape1_[i];
|
||||
}
|
||||
|
||||
op_para_.in_elements_num1_ = 1;
|
||||
for (size_t i = 0; i < input_shape2_.size(); ++i) {
|
||||
op_para_.in_elements_num1_ *= input_shape2_[i];
|
||||
}
|
||||
|
||||
size_t l = input_shape1_.size();
|
||||
for (size_t i = 0; i < output_shape_.size() - l; ++i) {
|
||||
(void)input_shape1_.insert(input_shape1_.begin(), 1);
|
||||
}
|
||||
l = input_shape2_.size();
|
||||
for (size_t i = 0; i < output_shape_.size() - l; ++i) {
|
||||
(void)input_shape2_.insert(input_shape2_.begin(), 1);
|
||||
}
|
||||
CPUKernelUtils::GetElementNumEveryDim(input_shape1_, &input_element_num1_);
|
||||
CPUKernelUtils::GetElementNumEveryDim(input_shape2_, &input_element_num2_);
|
||||
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
|
||||
|
||||
InitComputeFunc();
|
||||
}
|
||||
|
||||
void Sub(const T *input1, const T *input2, T *out);
|
||||
void Add(const T *input1, const T *input2, T *out);
|
||||
void Mul(const T *input1, const T *input2, T *out);
|
||||
void RealDiv(const T *input1, const T *input2, T *out);
|
||||
void Div(const T *input1, const T *input2, T *out);
|
||||
void FloorDiv(const T *input1, const T *input2, T *out);
|
||||
void Mod(const T *input1, const T *input2, T *out);
|
||||
void FloorMod(const T *input1, const T *input2, T *out);
|
||||
void Pow(const T *input1, const T *input2, T *out);
|
||||
void AssignAdd(T *input1, const T *input2, T *out);
|
||||
void Atan2(const T *input1, const T *input2, T *out);
|
||||
void SquaredDifference(const T *input1, const T *input2, T *out);
|
||||
|
||||
bool RunFunc(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
auto *input1 = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
const auto *input2 = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
auto *output = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
if (output_size_ == 0) {
|
||||
MS_LOG(WARNING) << kernel_name_ << " output shape contain 0, output_shape: " << output_shape_;
|
||||
return true;
|
||||
}
|
||||
if (kernel_name_ == prim::kPrimAssignAdd->name()) {
|
||||
AssignAdd(input1, input2, output);
|
||||
} else {
|
||||
compute_func_(this, input1, input2, output);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
void InitComputeFunc() {
|
||||
if (kernel_name_ == prim::kPrimAssignAdd->name()) {
|
||||
return;
|
||||
}
|
||||
static const std::unordered_map<std::string, TypeComputeFunc> arithmeticMathFuncMap{
|
||||
{prim::kPrimAdd->name(), &ArithmeticCpuTypeFunc<T>::Add},
|
||||
{prim::kPrimSub->name(), &ArithmeticCpuTypeFunc<T>::Sub},
|
||||
{prim::kPrimMul->name(), &ArithmeticCpuTypeFunc<T>::Mul},
|
||||
{prim::kPrimDiv->name(), &ArithmeticCpuTypeFunc<T>::Div},
|
||||
{prim::kPrimMod->name(), &ArithmeticCpuTypeFunc<T>::Mod},
|
||||
{prim::kPrimFloorMod->name(), &ArithmeticCpuTypeFunc<T>::FloorMod},
|
||||
{prim::kPrimPow->name(), &ArithmeticCpuTypeFunc<T>::Pow},
|
||||
{prim::kPrimFloorDiv->name(), &ArithmeticCpuTypeFunc<T>::FloorDiv},
|
||||
{prim::kPrimAtan2->name(), &ArithmeticCpuTypeFunc<T>::Atan2},
|
||||
{prim::kPrimRealDiv->name(), &ArithmeticCpuTypeFunc<T>::RealDiv},
|
||||
{prim::kPrimSquaredDifference->name(), &ArithmeticCpuTypeFunc<T>::SquaredDifference}};
|
||||
if (arithmeticMathFuncMap.find(kernel_name_) == arithmeticMathFuncMap.end()) {
|
||||
MS_LOG(EXCEPTION) << "For 'Arithmetic', only supports operators in " << Unorderedmap2Str(arithmeticMathFuncMap)
|
||||
<< ", but got " << kernel_name_;
|
||||
}
|
||||
compute_func_ = arithmeticMathFuncMap.at(kernel_name_);
|
||||
}
|
||||
|
||||
std::string kernel_name_;
|
||||
|
||||
size_t output_size_{1};
|
||||
ArithmeticParameter op_para_{};
|
||||
|
||||
std::vector<size_t> input_shape1_;
|
||||
std::vector<size_t> input_shape2_;
|
||||
std::vector<size_t> input_element_num1_;
|
||||
std::vector<size_t> input_element_num2_;
|
||||
std::vector<size_t> output_shape_;
|
||||
std::vector<size_t> output_element_num_;
|
||||
|
||||
using TypeComputeFunc = std::function<void(ArithmeticCpuTypeFunc *, const T *in_x, const T *in_y, T *out)>;
|
||||
TypeComputeFunc compute_func_{nullptr};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticCpuTypeFunc<T>::AssignAdd(T *input1, const T *input2, T *out) {
|
||||
auto task = [&input1, &input2, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = input1[i] + input2[i];
|
||||
|
@ -73,7 +193,7 @@ void ArithmeticCpuKernelMod<T>::AssignAdd(T *input1, const T *input2, T *out) {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticCpuKernelMod<T>::Add(const T *input1, const T *input2, T *out) {
|
||||
void ArithmeticCpuTypeFunc<T>::Add(const T *input1, const T *input2, T *out) {
|
||||
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
|
||||
auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
|
||||
auto iter = base_iter;
|
||||
|
@ -87,7 +207,7 @@ void ArithmeticCpuKernelMod<T>::Add(const T *input1, const T *input2, T *out) {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticCpuKernelMod<T>::Sub(const T *input1, const T *input2, T *out) {
|
||||
void ArithmeticCpuTypeFunc<T>::Sub(const T *input1, const T *input2, T *out) {
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
if (input_shape1_ == input_shape2_) {
|
||||
auto task = [this, input1, input2, out](size_t start, size_t end) {
|
||||
|
@ -122,7 +242,7 @@ void ArithmeticCpuKernelMod<T>::Sub(const T *input1, const T *input2, T *out) {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticCpuKernelMod<T>::Mul(const T *input1, const T *input2, T *out) {
|
||||
void ArithmeticCpuTypeFunc<T>::Mul(const T *input1, const T *input2, T *out) {
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
if (input_shape1_ == input_shape2_) {
|
||||
auto task = [this, input1, input2, out](size_t start, size_t end) {
|
||||
|
@ -160,7 +280,7 @@ void ArithmeticCpuKernelMod<T>::Mul(const T *input1, const T *input2, T *out) {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticCpuKernelMod<T>::RealDiv(const T *input1, const T *input2, T *out) {
|
||||
void ArithmeticCpuTypeFunc<T>::RealDiv(const T *input1, const T *input2, T *out) {
|
||||
if (input_shape1_ == input_shape2_) {
|
||||
auto task = [&](size_t start, size_t end) {
|
||||
ElementRealDiv<T>(input1 + start, input2 + start, out + start, end - start, 1, 1);
|
||||
|
@ -211,7 +331,7 @@ void ArithmeticCpuKernelMod<T>::RealDiv(const T *input1, const T *input2, T *out
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticCpuKernelMod<T>::Div(const T *input1, const T *input2, T *out) {
|
||||
void ArithmeticCpuTypeFunc<T>::Div(const T *input1, const T *input2, T *out) {
|
||||
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
|
||||
auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
|
||||
auto iter = base_iter;
|
||||
|
@ -240,7 +360,7 @@ void ArithmeticCpuKernelMod<T>::Div(const T *input1, const T *input2, T *out) {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticCpuKernelMod<T>::FloorDiv(const T *input1, const T *input2, T *out) {
|
||||
void ArithmeticCpuTypeFunc<T>::FloorDiv(const T *input1, const T *input2, T *out) {
|
||||
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
|
||||
auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
|
||||
auto iter = base_iter;
|
||||
|
@ -269,7 +389,7 @@ void ArithmeticCpuKernelMod<T>::FloorDiv(const T *input1, const T *input2, T *ou
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticCpuKernelMod<T>::Mod(const T *input1, const T *input2, T *out) {
|
||||
void ArithmeticCpuTypeFunc<T>::Mod(const T *input1, const T *input2, T *out) {
|
||||
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
|
||||
auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
|
||||
auto iter = base_iter;
|
||||
|
@ -291,7 +411,7 @@ void ArithmeticCpuKernelMod<T>::Mod(const T *input1, const T *input2, T *out) {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticCpuKernelMod<T>::FloorMod(const T *input1, const T *input2, T *out) {
|
||||
void ArithmeticCpuTypeFunc<T>::FloorMod(const T *input1, const T *input2, T *out) {
|
||||
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
|
||||
auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
|
||||
auto iter = base_iter;
|
||||
|
@ -308,7 +428,7 @@ void ArithmeticCpuKernelMod<T>::FloorMod(const T *input1, const T *input2, T *ou
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticCpuKernelMod<T>::Pow(const T *input1, const T *input2, T *out) {
|
||||
void ArithmeticCpuTypeFunc<T>::Pow(const T *input1, const T *input2, T *out) {
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
auto is_power_single = [this]() {
|
||||
bool is_power_single = false;
|
||||
|
@ -365,7 +485,7 @@ void ArithmeticCpuKernelMod<T>::Pow(const T *input1, const T *input2, T *out) {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticCpuKernelMod<T>::SquaredDifference(const T *input1, const T *input2, T *out) {
|
||||
void ArithmeticCpuTypeFunc<T>::SquaredDifference(const T *input1, const T *input2, T *out) {
|
||||
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
|
||||
auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
|
||||
auto iter = base_iter;
|
||||
|
@ -384,7 +504,7 @@ void ArithmeticCpuKernelMod<T>::SquaredDifference(const T *input1, const T *inpu
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticCpuKernelMod<T>::Atan2(const T *input1, const T *input2, T *out) {
|
||||
void ArithmeticCpuTypeFunc<T>::Atan2(const T *input1, const T *input2, T *out) {
|
||||
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
|
||||
auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
|
||||
auto iter = base_iter;
|
||||
|
@ -399,87 +519,172 @@ void ArithmeticCpuKernelMod<T>::Atan2(const T *input1, const T *input2, T *out)
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticCpuKernelMod<T>::InitComputeFunc() {
|
||||
if (kernel_name_ == prim::kPrimAssignAdd->name()) {
|
||||
return;
|
||||
}
|
||||
static const std::unordered_map<std::string, TypeComputeFunc> arithmeticMathFuncMap{
|
||||
{prim::kPrimAdd->name(), &ArithmeticCpuKernelMod<T>::Add},
|
||||
{prim::kPrimSub->name(), &ArithmeticCpuKernelMod<T>::Sub},
|
||||
{prim::kPrimMul->name(), &ArithmeticCpuKernelMod<T>::Mul},
|
||||
{prim::kPrimDiv->name(), &ArithmeticCpuKernelMod<T>::Div},
|
||||
{prim::kPrimMod->name(), &ArithmeticCpuKernelMod<T>::Mod},
|
||||
{prim::kPrimFloorMod->name(), &ArithmeticCpuKernelMod<T>::FloorMod},
|
||||
{prim::kPrimPow->name(), &ArithmeticCpuKernelMod<T>::Pow},
|
||||
{prim::kPrimFloorDiv->name(), &ArithmeticCpuKernelMod<T>::FloorDiv},
|
||||
{prim::kPrimAtan2->name(), &ArithmeticCpuKernelMod<T>::Atan2},
|
||||
{prim::kPrimRealDiv->name(), &ArithmeticCpuKernelMod<T>::RealDiv},
|
||||
{prim::kPrimSquaredDifference->name(), &ArithmeticCpuKernelMod<T>::SquaredDifference}};
|
||||
if (arithmeticMathFuncMap.find(kernel_name_) == arithmeticMathFuncMap.end()) {
|
||||
MS_LOG(EXCEPTION) << "For 'Arithmetic', only supports operators in " << Unorderedmap2Str(arithmeticMathFuncMap)
|
||||
<< ", but got " << kernel_name_;
|
||||
}
|
||||
compute_func_ = arithmeticMathFuncMap.at(kernel_name_);
|
||||
std::shared_ptr<CpuKernelFunc> SpecializeArithFunc(const CNodePtr &kernel_node) {
|
||||
return std::make_shared<ArithmeticCpuTypeFunc<T>>(kernel_node);
|
||||
}
|
||||
using ArithmeticCpuFuncCreator = std::function<std::shared_ptr<CpuKernelFunc>(const CNodePtr &)>;
|
||||
static std::map<std::string, std::vector<std::pair<KernelAttr, ArithmeticCpuFuncCreator>>> kernel_attr_list = {
|
||||
{kSub,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
SpecializeArithFunc<int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SpecializeArithFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
SpecializeArithFunc<int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
SpecializeArithFunc<double>}}},
|
||||
{kMul,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
SpecializeArithFunc<int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SpecializeArithFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
SpecializeArithFunc<int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
SpecializeArithFunc<double>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
SpecializeArithFunc<uint8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithFunc<bool>}}},
|
||||
{prim::kPrimDiv->name(),
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
SpecializeArithFunc<int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SpecializeArithFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
SpecializeArithFunc<int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
SpecializeArithFunc<double>}}},
|
||||
{prim::kPrimPow->name(),
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
SpecializeArithFunc<int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SpecializeArithFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
SpecializeArithFunc<int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
SpecializeArithFunc<double>}}},
|
||||
{kRealDiv,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
SpecializeArithFunc<int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
SpecializeArithFunc<float16>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SpecializeArithFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
SpecializeArithFunc<int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
SpecializeArithFunc<double>}}},
|
||||
{prim::kPrimFloorDiv->name(),
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
SpecializeArithFunc<int8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
SpecializeArithFunc<int>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
SpecializeArithFunc<int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
SpecializeArithFunc<uint8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
||||
SpecializeArithFunc<uint16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
SpecializeArithFunc<float16>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SpecializeArithFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
SpecializeArithFunc<double>}}},
|
||||
{prim::kPrimMod->name(),
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
SpecializeArithFunc<int>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SpecializeArithFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
SpecializeArithFunc<int64_t>}}},
|
||||
{prim::kPrimFloorMod->name(),
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
SpecializeArithFunc<int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
SpecializeArithFunc<int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SpecializeArithFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
SpecializeArithFunc<float16>}}},
|
||||
{kAssignAdd,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
SpecializeArithFunc<int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SpecializeArithFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
SpecializeArithFunc<int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
SpecializeArithFunc<double>}}},
|
||||
{prim::kPrimSquaredDifference->name(),
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
SpecializeArithFunc<int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
SpecializeArithFunc<float16>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SpecializeArithFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
SpecializeArithFunc<double>}}},
|
||||
{prim::kPrimAtan2->name(),
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SpecializeArithFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
SpecializeArithFunc<double>}}}};
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
void ArithmeticCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
input_shape1_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
input_shape2_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
if (output_shape_.size() == 0) {
|
||||
(void)output_shape_.insert(output_shape_.begin(), 1);
|
||||
if (kernel_name_ != kernel_type_) {
|
||||
MS_LOG(EXCEPTION) << "Need to be " << kernel_type_ << " but got kernel name as " << kernel_name_;
|
||||
}
|
||||
|
||||
output_size_ = 1;
|
||||
for (size_t i = 0; i < output_shape_.size(); ++i) {
|
||||
output_size_ *= output_shape_[i];
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "Arithmetic does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
|
||||
op_para_.in_elements_num0_ = 1;
|
||||
for (size_t i = 0; i < input_shape1_.size(); ++i) {
|
||||
op_para_.in_elements_num0_ *= input_shape1_[i];
|
||||
}
|
||||
|
||||
op_para_.in_elements_num1_ = 1;
|
||||
for (size_t i = 0; i < input_shape2_.size(); ++i) {
|
||||
op_para_.in_elements_num1_ *= input_shape2_[i];
|
||||
}
|
||||
|
||||
size_t l = input_shape1_.size();
|
||||
for (size_t i = 0; i < output_shape_.size() - l; ++i) {
|
||||
(void)input_shape1_.insert(input_shape1_.begin(), 1);
|
||||
}
|
||||
l = input_shape2_.size();
|
||||
for (size_t i = 0; i < output_shape_.size() - l; ++i) {
|
||||
(void)input_shape2_.insert(input_shape2_.begin(), 1);
|
||||
}
|
||||
CPUKernelUtils::GetElementNumEveryDim(input_shape1_, &input_element_num1_);
|
||||
CPUKernelUtils::GetElementNumEveryDim(input_shape2_, &input_element_num2_);
|
||||
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
|
||||
InitComputeFunc();
|
||||
func_obj_ = kernel_attr_list[kernel_name_][index].second(kernel_node);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool ArithmeticCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
|
||||
auto *input1 = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
const auto *input2 = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
auto *output = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
if (output_size_ == 0) {
|
||||
MS_LOG(WARNING) << kernel_name_ << " output shape contain 0, output_shape: " << output_shape_;
|
||||
return true;
|
||||
std::vector<KernelAttr> ArithmeticCpuKernelMod::GetOpSupport() {
|
||||
auto iter = kernel_attr_list.find(kernel_type_);
|
||||
if (iter == kernel_attr_list.end()) {
|
||||
MS_LOG(EXCEPTION) << "Arithmetic cpu does not support " << kernel_type_;
|
||||
}
|
||||
if (kernel_name_ == prim::kPrimAssignAdd->name()) {
|
||||
AssignAdd(input1, input2, output);
|
||||
} else {
|
||||
compute_func_(this, input1, input2, output);
|
||||
}
|
||||
return true;
|
||||
|
||||
std::vector<KernelAttr> support_list;
|
||||
std::transform(iter->second.begin(), iter->second.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, ArithmeticCpuFuncCreator> &pair) { return pair.first; });
|
||||
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Sub,
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kSub); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Mul,
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kMul); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Div,
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimDiv->name()); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Pow,
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimPow->name()); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, RealDiv,
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kRealDiv); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, FloorDiv, []() {
|
||||
return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimFloorDiv->name());
|
||||
});
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Mod,
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimMod->name()); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, FloorMod, []() {
|
||||
return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimFloorMod->name());
|
||||
});
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, AssignAdd,
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kAssignAdd); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, SquaredDifference, []() {
|
||||
return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimSquaredDifference->name());
|
||||
});
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Atan2,
|
||||
[]() { return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimAtan2->name()); });
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,216 +18,39 @@
|
|||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ARITHMETIC_CPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
#include "plugin/device/cpu/kernel/nnacl/arithmetic.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class ArithmeticCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
ArithmeticCpuKernelMod() = default;
|
||||
explicit ArithmeticCpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
|
||||
~ArithmeticCpuKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
const size_t kInputsNum = 2;
|
||||
const size_t kOutputsNum = 1;
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
|
||||
return func_obj_->RunFunc(inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
void InitComputeFunc();
|
||||
void Sub(const T *input1, const T *input2, T *out);
|
||||
void Add(const T *input1, const T *input2, T *out);
|
||||
void Mul(const T *input1, const T *input2, T *out);
|
||||
void RealDiv(const T *input1, const T *input2, T *out);
|
||||
void Div(const T *input1, const T *input2, T *out);
|
||||
void FloorDiv(const T *input1, const T *input2, T *out);
|
||||
void Mod(const T *input1, const T *input2, T *out);
|
||||
void FloorMod(const T *input1, const T *input2, T *out);
|
||||
void Pow(const T *input1, const T *input2, T *out);
|
||||
void AssignAdd(T *input1, const T *input2, T *out);
|
||||
void Atan2(const T *input1, const T *input2, T *out);
|
||||
void SquaredDifference(const T *input1, const T *input2, T *out);
|
||||
|
||||
using TypeComputeFunc = std::function<void(ArithmeticCpuKernelMod *, const T *in_x, const T *in_y, T *out)>;
|
||||
TypeComputeFunc compute_func_{nullptr};
|
||||
size_t output_size_{1};
|
||||
ArithmeticParameter op_para_{};
|
||||
|
||||
std::vector<size_t> input_shape1_;
|
||||
std::vector<size_t> input_shape2_;
|
||||
std::vector<size_t> input_element_num1_;
|
||||
std::vector<size_t> input_element_num2_;
|
||||
std::vector<size_t> output_shape_;
|
||||
std::vector<size_t> output_element_num_;
|
||||
std::shared_ptr<CpuKernelFunc> func_obj_;
|
||||
std::string kernel_type_{"Unknown"};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Sub, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ArithmeticCpuKernelMod, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Sub, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Sub, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ArithmeticCpuKernelMod, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Sub, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticCpuKernelMod, double);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ArithmeticCpuKernelMod, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Mul, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Mul, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ArithmeticCpuKernelMod, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Mul, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticCpuKernelMod, double);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Mul, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
ArithmeticCpuKernelMod, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Mul, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticCpuKernelMod, bool);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Div, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ArithmeticCpuKernelMod, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Div, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Div, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ArithmeticCpuKernelMod, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Div, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticCpuKernelMod, double);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Pow, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ArithmeticCpuKernelMod, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Pow, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Pow, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ArithmeticCpuKernelMod, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Pow, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticCpuKernelMod, double);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
RealDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ArithmeticCpuKernelMod, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
RealDiv,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
ArithmeticCpuKernelMod, float16);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
RealDiv,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
RealDiv, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ArithmeticCpuKernelMod, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
RealDiv,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticCpuKernelMod, double);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
ArithmeticCpuKernelMod, int8_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ArithmeticCpuKernelMod, int);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ArithmeticCpuKernelMod, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
FloorDiv,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
ArithmeticCpuKernelMod, float16);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
FloorDiv,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ArithmeticCpuKernelMod, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
FloorDiv,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticCpuKernelMod, double);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
ArithmeticCpuKernelMod, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
FloorDiv,
|
||||
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
||||
ArithmeticCpuKernelMod, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Mod, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ArithmeticCpuKernelMod, int);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Mod, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Mod, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ArithmeticCpuKernelMod, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
FloorMod, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ArithmeticCpuKernelMod, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
FloorMod, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ArithmeticCpuKernelMod, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
FloorMod,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
FloorMod,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
ArithmeticCpuKernelMod, float16);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
AssignAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ArithmeticCpuKernelMod, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
AssignAdd,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
AssignAdd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ArithmeticCpuKernelMod, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
AssignAdd,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticCpuKernelMod, double);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
SquaredDifference,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ArithmeticCpuKernelMod, int);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
SquaredDifference,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
ArithmeticCpuKernelMod, float16);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
SquaredDifference,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
SquaredDifference,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticCpuKernelMod, double);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Atan2,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Atan2,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticCpuKernelMod, double);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -20,6 +20,9 @@
|
|||
#include <cmath>
|
||||
#include <unordered_map>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
|
||||
|
@ -27,13 +30,104 @@ namespace mindspore {
|
|||
namespace kernel {
|
||||
namespace {
|
||||
constexpr size_t kMaxLessSerialSize = 15000;
|
||||
constexpr size_t kInputsNum = 2;
|
||||
constexpr size_t kOutputsNum = 1;
|
||||
} // namespace
|
||||
constexpr auto kEqual = "Equal";
|
||||
constexpr auto kNotEqual = "NotEqual";
|
||||
|
||||
template <typename T>
|
||||
class ArithLogicCpuTypeFunc : public CpuKernelFunc {
|
||||
public:
|
||||
ArithLogicCpuTypeFunc() = default;
|
||||
~ArithLogicCpuTypeFunc() override = default;
|
||||
bool RunFunc(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
const auto *input1 = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
const auto *input2 = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
bool *output = reinterpret_cast<bool *>(outputs[0]->addr);
|
||||
compute_func_(this, input1, input2, output);
|
||||
return true;
|
||||
}
|
||||
|
||||
void InitFunc(const CNodePtr &kernel_node) override {
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
input_shape1_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
input_shape2_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
if (output_shape_.empty()) {
|
||||
(void)output_shape_.insert(output_shape_.begin(), 1);
|
||||
}
|
||||
|
||||
output_size_ = 1;
|
||||
for (size_t i = 0; i < output_shape_.size(); ++i) {
|
||||
output_size_ *= output_shape_[i];
|
||||
}
|
||||
|
||||
size_t l = input_shape1_.size();
|
||||
for (size_t i = 0; i < output_shape_.size() - l; ++i) {
|
||||
(void)input_shape1_.insert(input_shape1_.begin(), 1);
|
||||
}
|
||||
l = input_shape2_.size();
|
||||
for (size_t i = 0; i < output_shape_.size() - l; ++i) {
|
||||
(void)input_shape2_.insert(input_shape2_.begin(), 1);
|
||||
}
|
||||
CPUKernelUtils::GetElementNumEveryDim(input_shape1_, &input_element_num1_);
|
||||
CPUKernelUtils::GetElementNumEveryDim(input_shape2_, &input_element_num2_);
|
||||
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
|
||||
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
auto dtype_1 = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
|
||||
if (dtype_ != dtype_1) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
<< "', the 'input1' and 'input2' should have the same data type, but got type of 'input1': "
|
||||
<< dtype_ << ", and the type of 'input2': " << dtype_1;
|
||||
}
|
||||
static const std::unordered_map<std::string, TypeComputeFunc> arithmetic_logic_func_map{
|
||||
{prim::kPrimGreater->name(), &ArithLogicCpuTypeFunc<T>::Greater},
|
||||
{prim::kPrimGreaterEqual->name(), &ArithLogicCpuTypeFunc<T>::GreaterEqual},
|
||||
{prim::kPrimLogicalAnd->name(), &ArithLogicCpuTypeFunc<T>::LogicalAnd},
|
||||
{prim::kPrimLessEqual->name(), &ArithLogicCpuTypeFunc<T>::LessEqual},
|
||||
{prim::kPrimLogicalOr->name(), &ArithLogicCpuTypeFunc<T>::LogicalOr},
|
||||
{prim::kPrimLogicalXor->name(), &ArithLogicCpuTypeFunc<T>::LogicalXor},
|
||||
{prim::kPrimLess->name(), &ArithLogicCpuTypeFunc<T>::Less},
|
||||
{prim::kPrimNotEqual->name(), &ArithLogicCpuTypeFunc<T>::NotEqual},
|
||||
{prim::kPrimEqual->name(), &ArithLogicCpuTypeFunc<T>::Equal}};
|
||||
if (arithmetic_logic_func_map.find(kernel_name_) == arithmetic_logic_func_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "For 'ArithmeticLogic', only supports operators in "
|
||||
<< Unorderedmap2Str(arithmetic_logic_func_map) << ", but got " << kernel_name_;
|
||||
}
|
||||
compute_func_ = arithmetic_logic_func_map.at(kernel_name_);
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename Op>
|
||||
void BinaryOp(const T *input1, const T *input2, bool *out, Op op);
|
||||
|
||||
void Less(const T *input1, const T *input2, bool *out);
|
||||
void Equal(const T *input1, const T *input2, bool *out);
|
||||
void NotEqual(const T *input1, const T *input2, bool *out);
|
||||
void Greater(const T *input1, const T *input2, bool *out);
|
||||
void GreaterEqual(const T *input1, const T *input2, bool *out);
|
||||
void LessEqual(const T *input1, const T *input2, bool *out);
|
||||
void LogicalAnd(const T *input1, const T *input2, bool *out);
|
||||
void LogicalOr(const T *input1, const T *input2, bool *out);
|
||||
void LogicalXor(const T *input1, const T *input2, bool *out);
|
||||
|
||||
using TypeComputeFunc = std::function<void(ArithLogicCpuTypeFunc *, const T *, const T *, bool *)>;
|
||||
TypeComputeFunc compute_func_{nullptr};
|
||||
|
||||
std::string kernel_name_;
|
||||
size_t output_size_{1};
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
|
||||
std::vector<size_t> input_shape1_;
|
||||
std::vector<size_t> input_shape2_;
|
||||
std::vector<size_t> input_element_num1_;
|
||||
std::vector<size_t> input_element_num2_;
|
||||
std::vector<size_t> output_shape_;
|
||||
std::vector<size_t> output_element_num_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
template <typename Op>
|
||||
void ArithmeticLogicCpuKernelMod<T>::BinaryOp(const T *input1, const T *input2, bool *out, Op op) {
|
||||
void ArithLogicCpuTypeFunc<T>::BinaryOp(const T *input1, const T *input2, bool *out, Op op) {
|
||||
size_t input1_size = 1;
|
||||
size_t input2_size = 2;
|
||||
|
||||
|
@ -80,32 +174,32 @@ void ArithmeticLogicCpuKernelMod<T>::BinaryOp(const T *input1, const T *input2,
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticLogicCpuKernelMod<T>::Less(const T *input1, const T *input2, bool *out) {
|
||||
void ArithLogicCpuTypeFunc<T>::Less(const T *input1, const T *input2, bool *out) {
|
||||
BinaryOp(input1, input2, out, std::less<T>());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticLogicCpuKernelMod<T>::Equal(const T *input1, const T *input2, bool *out) {
|
||||
void ArithLogicCpuTypeFunc<T>::Equal(const T *input1, const T *input2, bool *out) {
|
||||
BinaryOp(input1, input2, out, std::equal_to<T>());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticLogicCpuKernelMod<T>::NotEqual(const T *input1, const T *input2, bool *out) {
|
||||
void ArithLogicCpuTypeFunc<T>::NotEqual(const T *input1, const T *input2, bool *out) {
|
||||
BinaryOp(input1, input2, out, std::not_equal_to<T>());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticLogicCpuKernelMod<T>::LogicalAnd(const T *input1, const T *input2, bool *out) {
|
||||
void ArithLogicCpuTypeFunc<T>::LogicalAnd(const T *input1, const T *input2, bool *out) {
|
||||
BinaryOp(input1, input2, out, std::logical_and<T>());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticLogicCpuKernelMod<T>::LogicalOr(const T *input1, const T *input2, bool *out) {
|
||||
void ArithLogicCpuTypeFunc<T>::LogicalOr(const T *input1, const T *input2, bool *out) {
|
||||
BinaryOp(input1, input2, out, std::logical_or<T>());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticLogicCpuKernelMod<T>::LogicalXor(const T *input1, const T *input2, bool *out) {
|
||||
void ArithLogicCpuTypeFunc<T>::LogicalXor(const T *input1, const T *input2, bool *out) {
|
||||
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
|
||||
auto task = [input1, input2, out, &base_iter](size_t start, size_t end) {
|
||||
auto iter = base_iter;
|
||||
|
@ -119,87 +213,172 @@ void ArithmeticLogicCpuKernelMod<T>::LogicalXor(const T *input1, const T *input2
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticLogicCpuKernelMod<T>::Greater(const T *input1, const T *input2, bool *out) {
|
||||
void ArithLogicCpuTypeFunc<T>::Greater(const T *input1, const T *input2, bool *out) {
|
||||
BinaryOp(input1, input2, out, std::greater<T>());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticLogicCpuKernelMod<T>::GreaterEqual(const T *input1, const T *input2, bool *out) {
|
||||
void ArithLogicCpuTypeFunc<T>::GreaterEqual(const T *input1, const T *input2, bool *out) {
|
||||
BinaryOp(input1, input2, out, std::greater_equal<T>());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticLogicCpuKernelMod<T>::LessEqual(const T *input1, const T *input2, bool *out) {
|
||||
void ArithLogicCpuTypeFunc<T>::LessEqual(const T *input1, const T *input2, bool *out) {
|
||||
BinaryOp(input1, input2, out, std::less_equal<T>());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticLogicCpuKernelMod<T>::InitComputeFunc() {
|
||||
static const std::unordered_map<std::string, TypeComputeFunc> arithmeticLogicFuncMap{
|
||||
{prim::kPrimGreater->name(), &ArithmeticLogicCpuKernelMod<T>::Greater},
|
||||
{prim::kPrimGreaterEqual->name(), &ArithmeticLogicCpuKernelMod<T>::GreaterEqual},
|
||||
{prim::kPrimLogicalAnd->name(), &ArithmeticLogicCpuKernelMod<T>::LogicalAnd},
|
||||
{prim::kPrimLessEqual->name(), &ArithmeticLogicCpuKernelMod<T>::LessEqual},
|
||||
{prim::kPrimLogicalOr->name(), &ArithmeticLogicCpuKernelMod<T>::LogicalOr},
|
||||
{prim::kPrimLogicalXor->name(), &ArithmeticLogicCpuKernelMod<T>::LogicalXor},
|
||||
{prim::kPrimLess->name(), &ArithmeticLogicCpuKernelMod<T>::Less},
|
||||
{prim::kPrimNotEqual->name(), &ArithmeticLogicCpuKernelMod<T>::NotEqual},
|
||||
{prim::kPrimEqual->name(), &ArithmeticLogicCpuKernelMod<T>::Equal}};
|
||||
if (arithmeticLogicFuncMap.find(kernel_name_) == arithmeticLogicFuncMap.end()) {
|
||||
MS_LOG(EXCEPTION) << "For 'ArithmeticLogic', only supports operators in "
|
||||
<< Unorderedmap2Str(arithmeticLogicFuncMap) << ", but got " << kernel_name_;
|
||||
std::shared_ptr<CpuKernelFunc> SpecializeArithLogFunc() {
|
||||
return std::make_shared<ArithLogicCpuTypeFunc<T>>();
|
||||
}
|
||||
using ArithLogicCpuFuncCreator = std::function<std::shared_ptr<CpuKernelFunc>()>;
|
||||
static std::map<std::string, std::vector<std::pair<KernelAttr, ArithLogicCpuFuncCreator>>> kernel_attr_lists = {
|
||||
{prim::kPrimLess->name(),
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<int>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<double>}}},
|
||||
{kEqual,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<bool>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<int8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<int16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<int>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<uint8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<uint16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<uint32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<float16>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<double>}}},
|
||||
{kNotEqual,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<bool>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<int8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<int16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<int>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<uint8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<uint16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<uint32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<float16>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<double>}}},
|
||||
{prim::kPrimGreater->name(),
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<int>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<double>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<int64_t>}}},
|
||||
{prim::kPrimGreaterEqual->name(),
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<int>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<double>}}},
|
||||
{prim::kPrimLessEqual->name(),
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<int>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<double>}}},
|
||||
{prim::kPrimLogicalAnd->name(),
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<bool>}}},
|
||||
{prim::kPrimLogicalOr->name(),
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<bool>}}},
|
||||
{prim::kPrimLogicalXor->name(),
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
SpecializeArithLogFunc<bool>}}}};
|
||||
} // namespace
|
||||
|
||||
void ArithmeticLogicCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
if (kernel_name != kernel_type_) {
|
||||
MS_LOG(EXCEPTION) << "Need to be " << kernel_type_ << " but got kernel name as " << kernel_name;
|
||||
}
|
||||
compute_func_ = arithmeticLogicFuncMap.at(kernel_name_);
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "Arithmetic logic does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
|
||||
func_obj_ = kernel_attr_lists[kernel_name][index].second();
|
||||
func_obj_->InitFunc(kernel_node);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticLogicCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
input_shape1_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
input_shape2_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
if (output_shape_.empty()) {
|
||||
(void)output_shape_.insert(output_shape_.begin(), 1);
|
||||
std::vector<KernelAttr> ArithmeticLogicCpuKernelMod::GetOpSupport() {
|
||||
auto iter = kernel_attr_lists.find(kernel_type_);
|
||||
if (iter == kernel_attr_lists.end()) {
|
||||
MS_LOG(EXCEPTION) << "Arithmetic logic cpu does not support " << kernel_type_;
|
||||
}
|
||||
|
||||
output_size_ = 1;
|
||||
for (size_t i = 0; i < output_shape_.size(); ++i) {
|
||||
output_size_ *= output_shape_[i];
|
||||
}
|
||||
std::vector<KernelAttr> support_list;
|
||||
std::transform(iter->second.begin(), iter->second.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, ArithLogicCpuFuncCreator> &pair) { return pair.first; });
|
||||
|
||||
size_t l = input_shape1_.size();
|
||||
for (size_t i = 0; i < output_shape_.size() - l; ++i) {
|
||||
(void)input_shape1_.insert(input_shape1_.begin(), 1);
|
||||
}
|
||||
l = input_shape2_.size();
|
||||
for (size_t i = 0; i < output_shape_.size() - l; ++i) {
|
||||
(void)input_shape2_.insert(input_shape2_.begin(), 1);
|
||||
}
|
||||
CPUKernelUtils::GetElementNumEveryDim(input_shape1_, &input_element_num1_);
|
||||
CPUKernelUtils::GetElementNumEveryDim(input_shape2_, &input_element_num2_);
|
||||
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
|
||||
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
auto dtype_1 = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
|
||||
if (dtype_ != dtype_1) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
<< "', the 'input1' and 'input2' should have the same data type, but got type of 'input1': "
|
||||
<< dtype_ << ", and the type of 'input2': " << dtype_1;
|
||||
}
|
||||
InitComputeFunc();
|
||||
return support_list;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool ArithmeticLogicCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> & /* workspace */,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
|
||||
const auto *input1 = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
const auto *input2 = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
bool *output = reinterpret_cast<bool *>(outputs[0]->addr);
|
||||
compute_func_(this, input1, input2, output);
|
||||
return true;
|
||||
}
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Less, []() {
|
||||
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimLess->name());
|
||||
});
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Equal,
|
||||
[]() { return std::make_shared<ArithmeticLogicCpuKernelMod>(kEqual); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, NotEqual,
|
||||
[]() { return std::make_shared<ArithmeticLogicCpuKernelMod>(kNotEqual); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Greater, []() {
|
||||
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimGreater->name());
|
||||
});
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, GreaterEqual, []() {
|
||||
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimGreaterEqual->name());
|
||||
});
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LessEqual, []() {
|
||||
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimLessEqual->name());
|
||||
});
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LogicalAnd, []() {
|
||||
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimLogicalAnd->name());
|
||||
});
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LogicalOr, []() {
|
||||
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimLogicalOr->name());
|
||||
});
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LogicalXor, []() {
|
||||
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimLogicalXor->name());
|
||||
});
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,187 +18,39 @@
|
|||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ARITHMETIC_LOGIC_CPU_KERNEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class ArithmeticLogicCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
ArithmeticLogicCpuKernelMod() = default;
|
||||
explicit ArithmeticLogicCpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
|
||||
~ArithmeticLogicCpuKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
const size_t kInputsNum = 2;
|
||||
const size_t kOutputsNum = 1;
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
|
||||
return func_obj_->RunFunc(inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
void InitComputeFunc();
|
||||
template <typename Op>
|
||||
void BinaryOp(const T *input1, const T *input2, bool *out, Op op);
|
||||
|
||||
void Less(const T *input1, const T *input2, bool *out);
|
||||
void Equal(const T *input1, const T *input2, bool *out);
|
||||
void NotEqual(const T *input1, const T *input2, bool *out);
|
||||
void Greater(const T *input1, const T *input2, bool *out);
|
||||
void GreaterEqual(const T *input1, const T *input2, bool *out);
|
||||
void LessEqual(const T *input1, const T *input2, bool *out);
|
||||
void LogicalAnd(const T *input1, const T *input2, bool *out);
|
||||
void LogicalOr(const T *input1, const T *input2, bool *out);
|
||||
void LogicalXor(const T *input1, const T *input2, bool *out);
|
||||
|
||||
using TypeComputeFunc = std::function<void(ArithmeticLogicCpuKernelMod *, const T *, const T *, bool *)>;
|
||||
TypeComputeFunc compute_func_{nullptr};
|
||||
size_t output_size_{1};
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
|
||||
std::vector<size_t> input_shape1_;
|
||||
std::vector<size_t> input_shape2_;
|
||||
std::vector<size_t> input_element_num1_;
|
||||
std::vector<size_t> input_element_num2_;
|
||||
std::vector<size_t> output_shape_;
|
||||
std::vector<size_t> output_element_num_;
|
||||
std::shared_ptr<CpuKernelFunc> func_obj_;
|
||||
std::string kernel_type_{"Unknown"};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Less, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, int);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Less, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Less, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Less, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, double);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Equal, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, bool);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Equal, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, int8_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Equal, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, int16_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Equal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, int);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Equal, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Equal, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Equal, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Equal, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Equal, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, float16);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Equal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Equal, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, double);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
NotEqual, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, bool);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
NotEqual, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, int8_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
NotEqual, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, int16_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
NotEqual, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, int);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
NotEqual, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
NotEqual, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
NotEqual, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
NotEqual, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
NotEqual,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, float16);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
NotEqual,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
NotEqual,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, double);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Greater, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, int);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Greater,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Greater,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, double);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Greater, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
GreaterEqual,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, int);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
GreaterEqual,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
GreaterEqual,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
GreaterEqual,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, double);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, int);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
LessEqual,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
LessEqual,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, double);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
LogicalAnd, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, bool);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
LogicalOr, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, bool);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
LogicalXor, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticLogicCpuKernelMod, bool);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -19,22 +19,75 @@
|
|||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
#include "plugin/device/cpu/kernel/mkldnn/eltwise_cpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
using complex64 = std::complex<float>;
|
||||
using complex128 = std::complex<double>;
|
||||
|
||||
constexpr float kMaxNegSerialSize = 5000.0f;
|
||||
constexpr float kMaxSquareSerialSize = 5000.0f;
|
||||
constexpr size_t kInputsNum = 1;
|
||||
constexpr size_t kOutputsNum = 1;
|
||||
|
||||
constexpr auto kSquare = "Square";
|
||||
constexpr auto kNeg = "Neg";
|
||||
constexpr auto kSign = "Sign";
|
||||
constexpr auto kFloor = "Floor";
|
||||
constexpr auto kRint = "Rint";
|
||||
constexpr auto kRound = "Round";
|
||||
constexpr auto kReciprocal = "Reciprocal";
|
||||
constexpr auto kGeLU = "GeLU";
|
||||
constexpr auto kLogicalNot = "LogicalNot";
|
||||
constexpr auto kAsin = "Asin";
|
||||
constexpr auto kACos = "ACos";
|
||||
constexpr auto kAtan = "Atan";
|
||||
constexpr auto kSin = "Sin";
|
||||
constexpr auto kCos = "Cos";
|
||||
constexpr auto kTan = "Tan";
|
||||
constexpr auto kSinh = "Sinh";
|
||||
constexpr auto kCosh = "Cosh";
|
||||
constexpr auto kAsinh = "Asinh";
|
||||
constexpr auto kAcosh = "Acosh";
|
||||
constexpr auto kAtanh = "Atanh";
|
||||
constexpr auto kAbs = "Abs";
|
||||
constexpr auto kSqrt = "Sqrt";
|
||||
constexpr auto kRsqrt = "Rsqrt";
|
||||
|
||||
class ArithmeticSelfCpuKernelFunc : public CpuKernelFunc {
|
||||
public:
|
||||
ArithmeticSelfCpuKernelFunc() = default;
|
||||
~ArithmeticSelfCpuKernelFunc() override = default;
|
||||
|
||||
void InitFunc(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool RunFunc(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
void LaunchLogicalNot(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
template <typename T>
|
||||
void LaunchKernelComplex(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
std::string kernel_name_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void Square(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void Square(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = in[i] * in[i];
|
||||
|
@ -44,7 +97,7 @@ void Square(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t siz
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void Sign(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void Sign(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
if (in[i] < 0) {
|
||||
|
@ -60,7 +113,7 @@ void Sign(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void Neg(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void Neg(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = -in[i];
|
||||
|
@ -69,7 +122,7 @@ void Neg(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
|
|||
ParallelLaunch(task, size, kMaxNegSerialSize);
|
||||
}
|
||||
|
||||
void LogicalNot(ArithmeticSelfCpuKernelMod *content, const bool *in, bool *out, size_t size) {
|
||||
void LogicalNot(ArithmeticSelfCpuKernelFunc *content, const bool *in, bool *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = !in[i];
|
||||
|
@ -79,7 +132,7 @@ void LogicalNot(ArithmeticSelfCpuKernelMod *content, const bool *in, bool *out,
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void Floor(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void Floor(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(floor(in[i]));
|
||||
|
@ -89,7 +142,7 @@ void Floor(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void Rint(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void Rint(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(rint(in[i]));
|
||||
|
@ -99,7 +152,7 @@ void Rint(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void Round(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void Round(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(nearbyint(in[i]));
|
||||
|
@ -109,7 +162,7 @@ void Round(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void Reciprocal(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void Reciprocal(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(1.0 / in[i]);
|
||||
|
@ -119,7 +172,7 @@ void Reciprocal(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void Gelu(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void Gelu(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
auto factor_a = static_cast<T>(0.7978845608);
|
||||
auto factor_b = static_cast<T>(0.044715);
|
||||
|
@ -134,7 +187,7 @@ void Gelu(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void Asin(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void Asin(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(asin(static_cast<double>(in[i])));
|
||||
|
@ -144,7 +197,7 @@ void Asin(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ACos(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void ACos(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(acos(static_cast<double>(in[i])));
|
||||
|
@ -154,7 +207,7 @@ void ACos(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void Atan(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void Atan(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(atan(static_cast<double>(in[i])));
|
||||
|
@ -164,7 +217,7 @@ void Atan(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void Sin(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void Sin(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(sin(static_cast<double>(in[i])));
|
||||
|
@ -174,7 +227,7 @@ void Sin(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void Cos(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void Cos(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(cos(static_cast<double>(in[i])));
|
||||
|
@ -184,7 +237,7 @@ void Cos(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ComplexSin(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void ComplexSin(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(sin(in[i]));
|
||||
|
@ -194,7 +247,7 @@ void ComplexSin(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ComplexSinh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void ComplexSinh(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(sinh(in[i]));
|
||||
|
@ -204,7 +257,7 @@ void ComplexSinh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ComplexCos(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void ComplexCos(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(cos(in[i]));
|
||||
|
@ -214,7 +267,7 @@ void ComplexCos(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ComplexCosh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void ComplexCosh(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(cosh(in[i]));
|
||||
|
@ -224,7 +277,7 @@ void ComplexCosh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void Tan(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void Tan(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(tan(static_cast<double>(in[i])));
|
||||
|
@ -234,7 +287,7 @@ void Tan(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void Sinh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void Sinh(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(sinh(static_cast<double>(in[i])));
|
||||
|
@ -244,7 +297,7 @@ void Sinh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void Cosh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void Cosh(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(cosh(static_cast<double>(in[i])));
|
||||
|
@ -254,7 +307,7 @@ void Cosh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ComplexAsinh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void ComplexAsinh(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(asinh(in[i]));
|
||||
|
@ -264,7 +317,7 @@ void ComplexAsinh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void Asinh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void Asinh(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(asinh(static_cast<double>(in[i])));
|
||||
|
@ -274,7 +327,7 @@ void Asinh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ComplexAcosh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void ComplexAcosh(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(acosh(in[i]));
|
||||
|
@ -284,7 +337,7 @@ void ComplexAcosh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void Acosh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void Acosh(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(acosh(static_cast<double>(in[i])));
|
||||
|
@ -294,7 +347,7 @@ void Acosh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void Atanh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void Atanh(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(atanh(static_cast<double>(in[i])));
|
||||
|
@ -304,7 +357,7 @@ void Atanh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void Abs(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void Abs(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = abs(in[i]);
|
||||
|
@ -314,7 +367,7 @@ void Abs(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void Sqrt(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void Sqrt(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(sqrt(in[i]));
|
||||
|
@ -324,7 +377,7 @@ void Sqrt(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void Rsqrt(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
|
||||
void Rsqrt(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(1) / sqrt(in[i]);
|
||||
|
@ -337,17 +390,41 @@ template <typename T>
|
|||
void Identity(const T *in, T *out, size_t size) {
|
||||
(void)std::copy(in, in + size, out);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void ArithmeticSelfCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
template <typename T>
|
||||
bool IdentityCpuFunc(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs) {
|
||||
T *input = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
T *output = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
|
||||
Identity<T>(input, output, lens);
|
||||
return true;
|
||||
}
|
||||
|
||||
static std::vector<std::pair<KernelAttr, LaunchFunc>> identity_kernel_attr_lists = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), IdentityCpuFunc<uint64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), IdentityCpuFunc<int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), IdentityCpuFunc<uint32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), IdentityCpuFunc<int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), IdentityCpuFunc<uint16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), IdentityCpuFunc<int16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), IdentityCpuFunc<uint8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), IdentityCpuFunc<int8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), IdentityCpuFunc<complex64>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), IdentityCpuFunc<complex128>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), IdentityCpuFunc<double>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), IdentityCpuFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), IdentityCpuFunc<float16>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), IdentityCpuFunc<bool>}};
|
||||
|
||||
void ArithmeticSelfCpuKernelFunc::InitFunc(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
}
|
||||
|
||||
bool ArithmeticSelfCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
bool ArithmeticSelfCpuKernelFunc::RunFunc(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
|
||||
if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat16) {
|
||||
|
@ -373,8 +450,8 @@ bool ArithmeticSelfCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &i
|
|||
return true;
|
||||
}
|
||||
|
||||
void ArithmeticSelfCpuKernelMod::LaunchLogicalNot(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
void ArithmeticSelfCpuKernelFunc::LaunchLogicalNot(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
auto *input = reinterpret_cast<bool *>(inputs[0]->addr);
|
||||
auto *output = reinterpret_cast<bool *>(outputs[0]->addr);
|
||||
size_t lens = outputs[0]->size / sizeof(bool);
|
||||
|
@ -382,13 +459,13 @@ void ArithmeticSelfCpuKernelMod::LaunchLogicalNot(const std::vector<AddressPtr>
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticSelfCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
void ArithmeticSelfCpuKernelFunc::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
const auto *input = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto *output = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
const size_t lens = outputs[0]->size / sizeof(T);
|
||||
static const std::unordered_map<std::string,
|
||||
std::function<void(ArithmeticSelfCpuKernelMod *, const T *, T *, size_t)>>
|
||||
std::function<void(ArithmeticSelfCpuKernelFunc *, const T *, T *, size_t)>>
|
||||
arithmeticSelfFuncMap{{prim::kPrimSquare->name(), Square<T>},
|
||||
{prim::kPrimSign->name(), Sign<T>},
|
||||
{prim::kPrimNeg->name(), Neg<T>},
|
||||
|
@ -421,13 +498,13 @@ void ArithmeticSelfCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inp
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void ArithmeticSelfCpuKernelMod::LaunchKernelComplex(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
void ArithmeticSelfCpuKernelFunc::LaunchKernelComplex(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
const auto *input = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto *output = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
const size_t lens = outputs[0]->size / sizeof(T);
|
||||
static const std::unordered_map<std::string,
|
||||
std::function<void(ArithmeticSelfCpuKernelMod *, const T *, T *, size_t)>>
|
||||
std::function<void(ArithmeticSelfCpuKernelFunc *, const T *, T *, size_t)>>
|
||||
arithmeticSelfFuncMap{{prim::kPrimSquare->name(), Square<T>}, {prim::kPrimAcosh->name(), ComplexAcosh<T>},
|
||||
{prim::kPrimAsinh->name(), ComplexAsinh<T>}, {prim::kPrimNeg->name(), Neg<T>},
|
||||
{prim::kPrimSinh->name(), ComplexSinh<T>}, {prim::kPrimCosh->name(), ComplexCosh<T>},
|
||||
|
@ -435,22 +512,223 @@ void ArithmeticSelfCpuKernelMod::LaunchKernelComplex(const std::vector<AddressPt
|
|||
{prim::kPrimRsqrt->name(), Rsqrt<T>}};
|
||||
const auto func_pair = arithmeticSelfFuncMap.find(kernel_name_);
|
||||
if (arithmeticSelfFuncMap.find(kernel_name_) == arithmeticSelfFuncMap.end()) {
|
||||
MS_LOG(EXCEPTION) << "ArithmeticSelfCpuKernelMod does not support " << kernel_name_;
|
||||
MS_LOG(EXCEPTION) << "ArithmeticSelfCpuKernelFunc does not support " << kernel_name_;
|
||||
}
|
||||
func_pair->second(this, input, output, lens);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool IdentityCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
|
||||
T *input = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
T *output = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
|
||||
Identity<T>(input, output, lens);
|
||||
return true;
|
||||
// MKLDNN Sqrt
|
||||
class SqrtMKLKernelFunc : public CpuKernelFunc, private EltWiseCpuKernelMod {
|
||||
public:
|
||||
SqrtMKLKernelFunc() : EltWiseCpuKernelMod(kSqrt) {}
|
||||
~SqrtMKLKernelFunc() override = default;
|
||||
|
||||
void InitFunc(const CNodePtr &kernel_node) override {
|
||||
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
if (kernel_name != kSqrt) {
|
||||
MS_LOG(EXCEPTION) << "Should be " << kSqrt << ", but got " << kernel_name;
|
||||
}
|
||||
EltWiseCpuKernelMod::InitKernel(kernel_node);
|
||||
}
|
||||
|
||||
bool RunFunc(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return EltWiseCpuKernelMod::Launch(inputs, workspace, outputs);
|
||||
}
|
||||
};
|
||||
|
||||
std::shared_ptr<CpuKernelFunc> CreateArithSelfFunc() { return std::make_shared<ArithmeticSelfCpuKernelFunc>(); }
|
||||
using ArithFuncCreator = std::function<std::shared_ptr<CpuKernelFunc>()>;
|
||||
static std::map<std::string, std::vector<std::pair<KernelAttr, ArithFuncCreator>>> arith_kernel_attr_list_map = {
|
||||
{kRsqrt,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), CreateArithSelfFunc}}},
|
||||
{kSquare,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), CreateArithSelfFunc}}},
|
||||
{kNeg,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), CreateArithSelfFunc}}},
|
||||
{kSign,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
|
||||
{kFloor,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
|
||||
{kRint,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
|
||||
{kRound,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
|
||||
{kReciprocal,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
|
||||
{kGeLU, {{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc}}},
|
||||
{kLogicalNot, {{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), CreateArithSelfFunc}}},
|
||||
{kAsin,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
|
||||
{kACos,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
|
||||
{kAtan,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
|
||||
{kSin,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), CreateArithSelfFunc}}},
|
||||
{kCos,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), CreateArithSelfFunc}}},
|
||||
{kTan,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
|
||||
{kSinh,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), CreateArithSelfFunc}}},
|
||||
{kCosh,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), CreateArithSelfFunc}}},
|
||||
{kAsinh,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), CreateArithSelfFunc}}},
|
||||
{kAcosh,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
|
||||
{kAtanh,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
|
||||
{kAbs,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
|
||||
{kSqrt,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
[]() { return std::make_shared<SqrtMKLKernelFunc>(); }}}}};
|
||||
} // namespace
|
||||
|
||||
void ArithmeticSelfCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
if (kernel_name_ != kernel_type_) {
|
||||
MS_LOG(EXCEPTION) << "Need to be " << kernel_type_ << " but got kernel name as " << kernel_name_;
|
||||
}
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "ArithmeticSelf cpu does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
|
||||
func_obj_ = arith_kernel_attr_list_map[kernel_name_][index].second();
|
||||
func_obj_->InitFunc(kernel_node);
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> ArithmeticSelfCpuKernelMod::GetOpSupport() {
|
||||
auto iter = arith_kernel_attr_list_map.find(kernel_type_);
|
||||
if (iter == arith_kernel_attr_list_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Arithmetic self cpu does not support " << kernel_type_;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> support_list;
|
||||
std::transform(iter->second.begin(), iter->second.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, ArithFuncCreator> &pair) { return pair.first; });
|
||||
|
||||
return support_list;
|
||||
}
|
||||
|
||||
void IdentityCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "Identity cpu does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
|
||||
kernel_func_ = identity_kernel_attr_lists[index].second;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> IdentityCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> kernel_attr_list;
|
||||
std::transform(identity_kernel_attr_lists.begin(), identity_kernel_attr_lists.end(),
|
||||
std::back_inserter(kernel_attr_list),
|
||||
[](const std::pair<KernelAttr, LaunchFunc> &pair) { return pair.first; });
|
||||
return kernel_attr_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Rsqrt,
|
||||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kRsqrt); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Square,
|
||||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kSquare); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Neg,
|
||||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kNeg); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Sign,
|
||||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kSign); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Floor,
|
||||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kFloor); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Rint,
|
||||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kRint); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Round,
|
||||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kRound); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Reciprocal,
|
||||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kReciprocal); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, GeLU,
|
||||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kGeLU); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LogicalNot,
|
||||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kLogicalNot); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Asin,
|
||||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kAsin); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ACos,
|
||||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kACos); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Atan,
|
||||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kAtan); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Sin,
|
||||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kSin); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Cos,
|
||||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kCos); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Tan,
|
||||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kTan); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Sinh,
|
||||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kSinh); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Cosh,
|
||||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kCosh); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Asinh,
|
||||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kAsinh); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Acosh,
|
||||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kAcosh); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Atanh,
|
||||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kAtanh); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Abs,
|
||||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kAbs); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Sqrt,
|
||||
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kSqrt); });
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Identity, IdentityCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,211 +20,56 @@
|
|||
#include <complex>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
using complex64 = std::complex<float>;
|
||||
using complex128 = std::complex<double>;
|
||||
constexpr auto kUnknown = "Unknown";
|
||||
|
||||
class ArithmeticSelfCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
ArithmeticSelfCpuKernelMod() = default;
|
||||
explicit ArithmeticSelfCpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
|
||||
~ArithmeticSelfCpuKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return func_obj_->RunFunc(inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
void LaunchLogicalNot(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
template <typename T>
|
||||
void LaunchKernelComplex(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
std::shared_ptr<CpuKernelFunc> func_obj_;
|
||||
std::string kernel_type_{kUnknown};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class IdentityCpuKernelMod : public ArithmeticSelfCpuKernelMod {
|
||||
using LaunchFunc = std::function<bool(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &)>;
|
||||
class IdentityCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
IdentityCpuKernelMod() = default;
|
||||
~IdentityCpuKernelMod() override = default;
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
};
|
||||
MS_REG_CPU_KERNEL(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Neg, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Neg, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Neg, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Neg, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Sign, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Sign, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Sign, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Round, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Round, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(GeLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(LogicalNot, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Asin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Asin, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(ACos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(ACos, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Atan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Atan, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Sin, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Cos, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Sin, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Cos, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Tan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Tan, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Sinh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Sinh, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Cosh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Cosh, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Sinh, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Cosh, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Sinh, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Cosh, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Asinh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Asinh, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Acosh, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Acosh, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Asinh, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Asinh, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Acosh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Acosh, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Atanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Atanh, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Abs, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Abs, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ArithmeticSelfCpuKernelMod);
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), 1, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), 1, kernel_name_);
|
||||
return kernel_func_(inputs, outputs);
|
||||
}
|
||||
|
||||
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
IdentityCpuKernelMod, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
IdentityCpuKernelMod, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
IdentityCpuKernelMod, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
IdentityCpuKernelMod, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
||||
IdentityCpuKernelMod, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
IdentityCpuKernelMod, int16_t);
|
||||
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
IdentityCpuKernelMod, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
IdentityCpuKernelMod, int8_t);
|
||||
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
|
||||
IdentityCpuKernelMod, complex64);
|
||||
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
|
||||
IdentityCpuKernelMod, complex128);
|
||||
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
IdentityCpuKernelMod, double);
|
||||
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
IdentityCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
IdentityCpuKernelMod, float16);
|
||||
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
IdentityCpuKernelMod, bool);
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
LaunchFunc kernel_func_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -96,5 +96,25 @@ bool AssignCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std
|
|||
ParallelLaunchAutoSearch(task, total_size, this, ¶llel_search_info_);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> AssignCpuKernelMod::GetOpSupport() {
|
||||
static std::vector<KernelAttr> kernel_attr_list = {
|
||||
KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
||||
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)};
|
||||
|
||||
return kernel_attr_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Assign, AssignCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
#include <unordered_map>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -36,62 +36,14 @@ class AssignCpuKernelMod : public NativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
size_t batch_size_{1};
|
||||
size_t input_x_dtype_size_{4};
|
||||
TypeId input_x_dtype_{kTypeUnknown};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Assign, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
AssignCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Assign, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
AssignCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Assign, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
AssignCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Assign, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
AssignCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Assign, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
AssignCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Assign, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
AssignCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Assign, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
||||
AssignCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Assign, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
AssignCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Assign, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
AssignCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Assign,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
AssignCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Assign,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
AssignCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Assign,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
AssignCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -103,5 +103,7 @@ bool BiasAddCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const st
|
|||
|
||||
return true;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, BiasAdd, BiasAddCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include <memory>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -39,7 +39,6 @@ class BiasAddCpuKernelMod : public NativeCpuKernelMod {
|
|||
std::vector<size_t> input_shape_;
|
||||
std::vector<size_t> bias_shape_;
|
||||
};
|
||||
MS_REG_CPU_KERNEL(BiasAdd, KernelAttr(), BiasAddCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BIAS_ADD_CPU_KERNEL_H_
|
||||
|
|
|
@ -70,5 +70,7 @@ bool BiasAddGradCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, cons
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, BiasAddGrad, BiasAddGradCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include <memory>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -37,7 +37,6 @@ class BiasAddGradCpuKernelMod : public NativeCpuKernelMod {
|
|||
private:
|
||||
std::vector<size_t> input_shape_;
|
||||
};
|
||||
MS_REG_CPU_KERNEL(BiasAddGrad, KernelAttr(), BiasAddGradCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BIAS_ADD_GRAD_CPU_KERNEL_H_
|
||||
|
|
|
@ -129,5 +129,25 @@ void BinaryCrossEntropyCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
|||
<< reduction;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> BinaryCrossEntropyCpuKernelMod::GetOpSupport() {
|
||||
static std::vector<KernelAttr> kernel_attr_list = {
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)};
|
||||
|
||||
return kernel_attr_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, BinaryCrossEntropy, BinaryCrossEntropyCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include <string>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -36,6 +36,9 @@ class BinaryCrossEntropyCpuKernelMod : public NativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
void LaunchToScalar(const int &input_size, const int &reduction, T *loss, T *tmp_loss) const;
|
||||
|
@ -47,28 +50,6 @@ class BinaryCrossEntropyCpuKernelMod : public NativeCpuKernelMod {
|
|||
ReductionType reduction_{kNone};
|
||||
bool weight_defined_{false}; // true: there are 3 inputs, false: there are 2 inputs(no [weight])
|
||||
};
|
||||
MS_REG_CPU_KERNEL(BinaryCrossEntropy,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
BinaryCrossEntropyCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(BinaryCrossEntropy,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
BinaryCrossEntropyCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(
|
||||
BinaryCrossEntropy,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
BinaryCrossEntropyCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(
|
||||
BinaryCrossEntropy,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
BinaryCrossEntropyCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_BINARY_CROSS_ENTROPY_KERNEL_H
|
||||
|
|
|
@ -108,5 +108,24 @@ void BinaryCrossEntropyGradCpuKernelMod::InitKernel(const CNodePtr &kernel_node)
|
|||
<< reduction;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> BinaryCrossEntropyGradCpuKernelMod::GetOpSupport() {
|
||||
static std::vector<KernelAttr> kernel_attr_list = {KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)};
|
||||
|
||||
return kernel_attr_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, BinaryCrossEntropyGrad, BinaryCrossEntropyGradCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include <string>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
#include "plugin/device/cpu/kernel/binary_cross_entropy_cpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -35,6 +35,9 @@ class BinaryCrossEntropyGradCpuKernelMod : public NativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) const;
|
||||
|
@ -44,36 +47,6 @@ class BinaryCrossEntropyGradCpuKernelMod : public NativeCpuKernelMod {
|
|||
ReductionType reduction_{kNone};
|
||||
bool weight_defined_{false}; // true: there are 4 inputs, false: there are 3 inputs(no [weight])
|
||||
};
|
||||
MS_REG_CPU_KERNEL(BinaryCrossEntropyGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
BinaryCrossEntropyGradCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(BinaryCrossEntropyGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
BinaryCrossEntropyGradCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(BinaryCrossEntropyGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
BinaryCrossEntropyGradCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(BinaryCrossEntropyGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
BinaryCrossEntropyGradCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_BINARY_CROSS_ENTROPY_GRAD_KERNEL_H
|
||||
|
|
|
@ -15,12 +15,12 @@
|
|||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/boundingbox_decode_cpu_kernel.h"
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
void BoundingBoxDecodeCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
void BoundingBoxDecodeCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
|
@ -71,12 +71,13 @@ void BoundingBoxDecodeCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
|||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
<< "', the length of 'max_shape' should be at least 2, but got: " << max_shape_.size();
|
||||
}
|
||||
|
||||
InitTaskFunc(kernel_node);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool BoundingBoxDecodeCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
bool BoundingBoxDecodeCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
auto anchor_box = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto deltas = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
auto bboxes = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
|
@ -156,5 +157,30 @@ bool BoundingBoxDecodeCpuKernelMod<T>::Launch(const std::vector<kernel::AddressP
|
|||
ParallelLaunchAutoSearch(task, elem_num, this, ¶llel_search_info_);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, BoundingBoxDecodeCpuKernelMod::BoundingBoxDecodeFunc>>
|
||||
BoundingBoxDecodeCpuKernelMod::func_list_ = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&BoundingBoxDecodeCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
&BoundingBoxDecodeCpuKernelMod::LaunchKernel<float16>}};
|
||||
|
||||
void BoundingBoxDecodeCpuKernelMod::InitTaskFunc(const CNodePtr &kernel_node) {
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "BoundingBoxDecode does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> BoundingBoxDecodeCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, BoundingBoxDecodeFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, BoundingBoxDecode, BoundingBoxDecodeCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,15 +19,15 @@
|
|||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr size_t MIN_MAX_SHAPE_SIZE = 2;
|
||||
constexpr size_t INPUT_NUMS = 2;
|
||||
template <typename T>
|
||||
class BoundingBoxDecodeCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
BoundingBoxDecodeCpuKernelMod() = default;
|
||||
|
@ -35,26 +35,30 @@ class BoundingBoxDecodeCpuKernelMod : public NativeCpuKernelMod {
|
|||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
void InitTaskFunc(const CNodePtr &kernel_node);
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
using BoundingBoxDecodeFunc =
|
||||
std::function<bool(BoundingBoxDecodeCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
static std::vector<std::pair<KernelAttr, BoundingBoxDecodeFunc>> func_list_;
|
||||
BoundingBoxDecodeFunc kernel_func_;
|
||||
|
||||
std::vector<float> means_;
|
||||
std::vector<float> stds_;
|
||||
std::vector<int> max_shape_;
|
||||
float wh_ratio_clip_{0.016};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
BoundingBoxDecode,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
BoundingBoxDecodeCpuKernelMod, float);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
BoundingBoxDecode,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
BoundingBoxDecodeCpuKernelMod, float16);
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -15,12 +15,35 @@
|
|||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/boundingbox_encode_cpu_kernel.h"
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
void BoundingBoxEncodeCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
std::vector<std::pair<KernelAttr, BoundingBoxEncodeCpuKernelMod::BoundingBoxEncodeFunc>>
|
||||
BoundingBoxEncodeCpuKernelMod::func_list_ = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&BoundingBoxEncodeCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
&BoundingBoxEncodeCpuKernelMod::LaunchKernel<float16>}};
|
||||
|
||||
void BoundingBoxEncodeCpuKernelMod::InitTaskFunc(const CNodePtr &kernel_node) {
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "BoundingBoxEncode does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> BoundingBoxEncodeCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, BoundingBoxEncodeFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
void BoundingBoxEncodeCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
|
@ -61,12 +84,12 @@ void BoundingBoxEncodeCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
|||
"but got the length of 'means': "
|
||||
<< means_.size() << ", and the length of 'stds': " << stds_.size();
|
||||
}
|
||||
InitTaskFunc(kernel_node);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool BoundingBoxEncodeCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
bool BoundingBoxEncodeCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
auto anchor_box = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto groundtruth_box = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
auto deltas = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
|
@ -128,5 +151,7 @@ bool BoundingBoxEncodeCpuKernelMod<T>::Launch(const std::vector<kernel::AddressP
|
|||
|
||||
return true;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, BoundingBoxEncode, BoundingBoxEncodeCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,14 +19,14 @@
|
|||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr size_t INPUT_NUMS = 2;
|
||||
template <typename T>
|
||||
class BoundingBoxEncodeCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
BoundingBoxEncodeCpuKernelMod() = default;
|
||||
|
@ -34,24 +34,28 @@ class BoundingBoxEncodeCpuKernelMod : public NativeCpuKernelMod {
|
|||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
using BoundingBoxEncodeFunc =
|
||||
std::function<bool(BoundingBoxEncodeCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
static std::vector<std::pair<KernelAttr, BoundingBoxEncodeFunc>> func_list_;
|
||||
BoundingBoxEncodeFunc kernel_func_;
|
||||
|
||||
void InitTaskFunc(const CNodePtr &kernel_node);
|
||||
std::vector<float> means_;
|
||||
std::vector<float> stds_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
BoundingBoxEncode,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
BoundingBoxEncodeCpuKernelMod, float);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
BoundingBoxEncode,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
BoundingBoxEncodeCpuKernelMod, float16);
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/broadcast_to_cpu_kernel.h"
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/kernel/nnacl/errorcode.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -23,8 +25,43 @@ namespace {
|
|||
constexpr size_t kBroadcastToOutputsNum = 1;
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
void BroadcastToCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
std::map<std::string, std::vector<std::pair<KernelAttr, BroadcastToCpuKernelMod::BroadcastToFunc>>>
|
||||
BroadcastToCpuKernelMod::func_list_ = {
|
||||
{kBroadcastTo,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&BroadcastToCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
&BroadcastToCpuKernelMod::LaunchKernel<int>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
&BroadcastToCpuKernelMod::LaunchKernel<bool>}}},
|
||||
{kDynamicBroadcastTo,
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&BroadcastToCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
&BroadcastToCpuKernelMod::LaunchKernel<int>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
&BroadcastToCpuKernelMod::LaunchKernel<bool>}}}};
|
||||
|
||||
void BroadcastToCpuKernelMod::InitTaskFunc(const CNodePtr &kernel_node) {
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
if (kernel_name_ != kernel_type_) {
|
||||
MS_LOG(EXCEPTION) << "Suppose to be " << kernel_type_ << " but got " << kernel_name_;
|
||||
}
|
||||
|
||||
auto iter = func_list_.find(kernel_type_);
|
||||
if (iter == func_list_.end()) {
|
||||
MS_LOG(EXCEPTION) << "BroadcastTo cpu does not support " << kernel_type_;
|
||||
}
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "BroadcastTo does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
kernel_func_ = func_list_[kernel_type_][index].second;
|
||||
}
|
||||
|
||||
void BroadcastToCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
input_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
|
@ -40,10 +77,11 @@ void BroadcastToCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
|||
}
|
||||
shape_info_.input_shape_size_ = SizeToInt(input_shape_size);
|
||||
shape_info_.output_shape_size_ = SizeToInt(output_shape_size);
|
||||
|
||||
InitTaskFunc(kernel_node);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void BroadcastToCpuKernelMod<T>::CheckArgs() {
|
||||
void BroadcastToCpuKernelMod::CheckArgs() {
|
||||
size_t input_shape_size = input_shape_.size();
|
||||
size_t output_shape_size = output_shape_.size();
|
||||
if (output_shape_size < input_shape_size) {
|
||||
|
@ -71,8 +109,8 @@ void BroadcastToCpuKernelMod<T>::CheckArgs() {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
bool BroadcastToCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
bool BroadcastToCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kBroadcastToOutputsNum, kernel_name_);
|
||||
CheckArgs();
|
||||
const auto *input_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
|
@ -99,5 +137,22 @@ bool BroadcastToCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs, c
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> BroadcastToCpuKernelMod::GetOpSupport() {
|
||||
auto iter = func_list_.find(kernel_type_);
|
||||
if (iter == func_list_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Does not support " << kernel_type_ << "!";
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> support_list;
|
||||
std::transform(iter->second.begin(), iter->second.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, BroadcastToFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, BroadcastTo,
|
||||
[]() { return std::make_shared<BroadcastToCpuKernelMod>(kBroadcastTo); });
|
||||
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, DynamicBroadcastTo,
|
||||
[]() { return std::make_shared<BroadcastToCpuKernelMod>(kDynamicBroadcastTo); });
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,50 +18,54 @@
|
|||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BROADCAST_TO_CPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
#include "plugin/device/cpu/kernel/nnacl/base/broadcast_to.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
constexpr auto kBroadcastTo = "BroadcastTo";
|
||||
constexpr auto kDynamicBroadcastTo = "DynamicBroadcastTo";
|
||||
constexpr auto kUnknown = "Unknown";
|
||||
class BroadcastToCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
BroadcastToCpuKernelMod() = default;
|
||||
explicit BroadcastToCpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
|
||||
~BroadcastToCpuKernelMod() = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
void CheckArgs();
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
using BroadcastToFunc =
|
||||
std::function<bool(BroadcastToCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
static std::map<std::string, std::vector<std::pair<KernelAttr, BroadcastToFunc>>> func_list_;
|
||||
BroadcastToFunc kernel_func_;
|
||||
|
||||
void InitTaskFunc(const CNodePtr &kernel_node);
|
||||
std::vector<size_t> input_shape_;
|
||||
std::vector<size_t> output_shape_;
|
||||
BroadcastShapeInfo shape_info_{};
|
||||
std::string kernel_type_{kUnknown};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
BroadcastToCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
BroadcastToCpuKernelMod, int);
|
||||
MS_REG_CPU_KERNEL_T(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastToCpuKernelMod, bool);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
DynamicBroadcastTo,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
BroadcastToCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
DynamicBroadcastTo,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
BroadcastToCpuKernelMod, int);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
DynamicBroadcastTo,
|
||||
KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastToCpuKernelMod, bool);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -19,18 +19,30 @@
|
|||
#include <cmath>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
constexpr size_t kCastInputsNum = 1;
|
||||
constexpr size_t kCastOutputsNum = 1;
|
||||
} // namespace
|
||||
template <typename S, typename T>
|
||||
class CastCpuKernelFunc : public CpuKernelFunc {
|
||||
public:
|
||||
CastCpuKernelFunc() = default;
|
||||
~CastCpuKernelFunc() override = default;
|
||||
|
||||
bool RunFunc(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
TypeId source_dtype_{kTypeUnknown};
|
||||
TypeId target_dtype_{kTypeUnknown};
|
||||
};
|
||||
|
||||
template <typename S, typename T>
|
||||
void Cast(CastCpuKernelMod<S, T> *content, const S *in, T *out, size_t size) {
|
||||
void Cast(CastCpuKernelFunc<S, T> *content, const S *in, T *out, size_t size) {
|
||||
auto task = [&in, &out](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; i++) {
|
||||
out[i] = static_cast<T>(in[i]);
|
||||
|
@ -40,28 +52,186 @@ void Cast(CastCpuKernelMod<S, T> *content, const S *in, T *out, size_t size) {
|
|||
}
|
||||
|
||||
template <typename S, typename T>
|
||||
void CastCpuKernelMod<S, T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
source_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
target_dtype_ = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
|
||||
}
|
||||
|
||||
template <typename S, typename T>
|
||||
bool CastCpuKernelMod<S, T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCastInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCastOutputsNum, kernel_name_);
|
||||
if (outputs[0]->size == 0) {
|
||||
MS_LOG(WARNING) << "For '" << kernel_name_ << "', the memory size of output should be greater than 0, but got 0.";
|
||||
return true;
|
||||
}
|
||||
bool CastCpuKernelFunc<S, T>::RunFunc(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
const auto *input = reinterpret_cast<S *>(inputs[0]->addr);
|
||||
auto *output = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
MS_LOG(DEBUG) << "Type source: " << typeid(S).name() << "; target: " << typeid(T).name();
|
||||
Cast<S, T>(this, input, output, outputs[0]->size / sizeof(T));
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename S, typename T>
|
||||
std::shared_ptr<CpuKernelFunc> CreateCastFunc() {
|
||||
return std::make_shared<CastCpuKernelFunc<S, T>>();
|
||||
}
|
||||
using CastCpuKernelFuncCreator = std::function<std::shared_ptr<CpuKernelFunc>()>;
|
||||
|
||||
static std::vector<std::pair<KernelAttr, CastCpuKernelFuncCreator>> kernel_attr_lists = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<uint8_t, uint8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<uint8_t, uint16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<uint8_t, uint32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt64), CreateCastFunc<uint8_t, uint64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt8), CreateCastFunc<uint8_t, int8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt16), CreateCastFunc<uint8_t, int16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32), CreateCastFunc<uint8_t, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64), CreateCastFunc<uint8_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat16), CreateCastFunc<uint8_t, float16>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<uint8_t, float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<uint8_t, double>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool), CreateCastFunc<uint8_t, bool>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<uint16_t, uint8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<uint16_t, uint16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<uint16_t, uint32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt64), CreateCastFunc<uint16_t, uint64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt8), CreateCastFunc<uint16_t, int8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt16), CreateCastFunc<uint16_t, int16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32), CreateCastFunc<uint16_t, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64), CreateCastFunc<uint16_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat16), CreateCastFunc<uint16_t, float16>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<uint16_t, float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<uint16_t, double>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool), CreateCastFunc<uint16_t, bool>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<uint32_t, uint8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<uint32_t, uint16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<uint32_t, uint32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt64), CreateCastFunc<uint32_t, uint64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt8), CreateCastFunc<uint32_t, int8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt16), CreateCastFunc<uint32_t, int16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32), CreateCastFunc<uint32_t, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), CreateCastFunc<uint32_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat16), CreateCastFunc<uint32_t, float16>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<uint32_t, float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<uint32_t, double>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool), CreateCastFunc<uint32_t, bool>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<uint64_t, uint8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<uint64_t, uint16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<uint64_t, uint32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), CreateCastFunc<uint64_t, uint64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt8), CreateCastFunc<uint64_t, int8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt16), CreateCastFunc<uint64_t, int16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32), CreateCastFunc<uint64_t, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64), CreateCastFunc<uint64_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat16), CreateCastFunc<uint64_t, float16>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<uint64_t, float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<uint64_t, double>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeBool), CreateCastFunc<uint64_t, bool>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<int8_t, uint8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<int8_t, uint16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<int8_t, uint32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt64), CreateCastFunc<int8_t, uint64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), CreateCastFunc<int8_t, int8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt16), CreateCastFunc<int8_t, int16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32), CreateCastFunc<int8_t, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64), CreateCastFunc<int8_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat16), CreateCastFunc<int8_t, float16>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<int8_t, float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<int8_t, double>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool), CreateCastFunc<int8_t, bool>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<int16_t, uint8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<int16_t, uint16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<int16_t, uint32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt64), CreateCastFunc<int16_t, uint64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt8), CreateCastFunc<int16_t, int8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), CreateCastFunc<int16_t, int16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32), CreateCastFunc<int16_t, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64), CreateCastFunc<int16_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat16), CreateCastFunc<int16_t, float16>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<int16_t, float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<int16_t, double>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool), CreateCastFunc<int16_t, bool>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<int32_t, uint8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<int32_t, uint16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<int32_t, uint32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64), CreateCastFunc<int32_t, uint64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), CreateCastFunc<int32_t, int8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16), CreateCastFunc<int32_t, int16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateCastFunc<int32_t, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), CreateCastFunc<int32_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), CreateCastFunc<int32_t, float16>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<int32_t, float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<int32_t, double>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), CreateCastFunc<int32_t, bool>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<int64_t, uint8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<int64_t, uint16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<int64_t, uint32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64), CreateCastFunc<int64_t, uint64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), CreateCastFunc<int64_t, int8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16), CreateCastFunc<int64_t, int16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), CreateCastFunc<int64_t, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CreateCastFunc<int64_t, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), CreateCastFunc<int64_t, float16>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<int64_t, float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<int64_t, double>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), CreateCastFunc<int64_t, bool>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<float16, uint8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<float16, uint16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<float16, uint32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt64), CreateCastFunc<float16, uint64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt8), CreateCastFunc<float16, int8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt16), CreateCastFunc<float16, int16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), CreateCastFunc<float16, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64), CreateCastFunc<float16, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), CreateCastFunc<float16, float16>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<float16, float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<float16, double>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), CreateCastFunc<float16, bool>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<float, uint8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<float, uint16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<float, uint32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt64), CreateCastFunc<float, uint64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt8), CreateCastFunc<float, int8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt16), CreateCastFunc<float, int16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), CreateCastFunc<float, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64), CreateCastFunc<float, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat16), CreateCastFunc<float, float16>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<float, float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<float, double>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), CreateCastFunc<float, bool>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<double, uint8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<double, uint16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<double, uint32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt64), CreateCastFunc<double, uint64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt8), CreateCastFunc<double, int8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt16), CreateCastFunc<double, int16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32), CreateCastFunc<double, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64), CreateCastFunc<double, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat16), CreateCastFunc<double, float16>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<double, float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<double, double>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool), CreateCastFunc<double, bool>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<bool, uint8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<bool, uint16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<bool, uint32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt64), CreateCastFunc<bool, uint64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt8), CreateCastFunc<bool, int8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt16), CreateCastFunc<bool, int16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32), CreateCastFunc<bool, int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64), CreateCastFunc<bool, int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat16), CreateCastFunc<bool, float16>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<bool, float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<bool, double>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), CreateCastFunc<bool, bool>}};
|
||||
} // namespace
|
||||
|
||||
void CastCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
source_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
target_dtype_ = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
std::vector<KernelAttr> support_list;
|
||||
std::transform(kernel_attr_lists.begin(), kernel_attr_lists.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, CastCpuKernelFuncCreator> &pair) { return pair.first; });
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, support_list);
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "Cast does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
|
||||
kernel_func_ = kernel_attr_lists[index].second();
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Cast, CastCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,11 +22,10 @@
|
|||
#include <vector>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename S, typename T>
|
||||
class CastCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
CastCpuKernelMod() = default;
|
||||
|
@ -35,169 +34,24 @@ class CastCpuKernelMod : public NativeCpuKernelMod {
|
|||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
const size_t kCastInputsNum = 1;
|
||||
const size_t kCastOutputsNum = 1;
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCastInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCastOutputsNum, kernel_name_);
|
||||
if (outputs[0]->size == 0) {
|
||||
MS_LOG(WARNING) << "For '" << kernel_name_ << "', the memory size of output should be greater than 0, but got 0.";
|
||||
return true;
|
||||
}
|
||||
return kernel_func_->RunFunc(inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
private:
|
||||
TypeId source_dtype_{kTypeUnknown};
|
||||
TypeId target_dtype_{kTypeUnknown};
|
||||
|
||||
std::shared_ptr<CpuKernelFunc> kernel_func_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint8_t, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint8_t, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint8_t, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint8_t, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint8_t, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint8_t, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint8_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint8_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint8_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint8_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint8_t, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint8_t, bool);
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint16_t, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint16_t, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint16_t, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint16_t, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint16_t, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint16_t, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint16_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint16_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint16_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint16_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint16_t, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint16_t, bool);
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint32_t, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint32_t, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint32_t, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint32_t, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint32_t, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint32_t, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint32_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint32_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint32_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint32_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint32_t, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint32_t, bool);
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint64_t, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint64_t, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint64_t, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint64_t, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint64_t, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint64_t, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint64_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint64_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint64_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint64_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint64_t, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint64_t, bool);
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int8_t, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int8_t, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int8_t, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int8_t, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int8_t, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int8_t, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int8_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int8_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int8_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int8_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int8_t, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int8_t, bool);
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int16_t, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int16_t, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int16_t, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int16_t, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int16_t, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int16_t, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int16_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int16_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int16_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int16_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int16_t, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int16_t, bool);
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int32_t, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int32_t, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int32_t, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int32_t, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int32_t, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int32_t, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int32_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int32_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int32_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int32_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int32_t, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int32_t, bool);
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int64_t, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int64_t, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int64_t, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int64_t, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int64_t, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int64_t, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int64_t, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int64_t, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int64_t, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int64_t, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int64_t, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int64_t, bool);
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float16, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float16, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float16, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float16, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float16, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float16, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float16, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float16, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float16, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float16, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float16, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float16, bool);
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float, bool);
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, double, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, double, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, double, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, double, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, double, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, double, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, double, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, double, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, double, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, double, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, double, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, double, bool);
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, bool, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, bool, uint16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, bool, uint32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, bool, uint64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, bool, int8_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, bool, int16_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, bool, int32_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, bool, int64_t);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, bool, float16);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, bool, float);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, bool, double);
|
||||
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, bool, bool);
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -14,9 +14,10 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "plugin/device/cpu/kernel/check_valid_cpu_kernel.h"
|
||||
#include <functional>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -29,20 +30,26 @@ constexpr size_t kIndex1 = 1;
|
|||
constexpr size_t kIndex2 = 2;
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
void CheckValidCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
void CheckValidCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
anchor_box_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
img_metas_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
||||
output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "CheckValid does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool CheckValidCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CheckParams(inputs, outputs);
|
||||
bool CheckValidCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CheckParams<T>(inputs, outputs);
|
||||
auto anchor_box = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto img_metas = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
auto output = reinterpret_cast<bool *>(outputs[0]->addr);
|
||||
|
@ -77,8 +84,8 @@ bool CheckValidCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &in
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void CheckValidCpuKernelMod<T>::CheckParams(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
void CheckValidCpuKernelMod::CheckParams(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
// inputs: anchor_box, img_metas
|
||||
if (inputs.size() != kInputSize) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs should be " << kInputSize << ", but got "
|
||||
|
@ -96,5 +103,24 @@ void CheckValidCpuKernelMod<T>::CheckParams(const std::vector<AddressPtr> &input
|
|||
<< Vector2Str(output_shape_) << ", the shape of 'img_metas': " << Vector2Str(img_metas_shape_);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, CheckValidCpuKernelMod::CheckValidFunc>> CheckValidCpuKernelMod::func_list_ = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
&CheckValidCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
|
||||
&CheckValidCpuKernelMod::LaunchKernel<float16>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
|
||||
&CheckValidCpuKernelMod::LaunchKernel<int16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
|
||||
&CheckValidCpuKernelMod::LaunchKernel<uint8_t>}};
|
||||
|
||||
std::vector<KernelAttr> CheckValidCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, CheckValidFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, CheckValid, CheckValidCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,14 +18,14 @@
|
|||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHECK_VALID_CPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr size_t COORDINATE = 4;
|
||||
template <typename T>
|
||||
class CheckValidCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
CheckValidCpuKernelMod() = default;
|
||||
|
@ -33,33 +33,29 @@ class CheckValidCpuKernelMod : public NativeCpuKernelMod {
|
|||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
void CheckParams(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
using CheckValidFunc =
|
||||
std::function<bool(CheckValidCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
static std::vector<std::pair<KernelAttr, CheckValidFunc>> func_list_;
|
||||
CheckValidFunc kernel_func_;
|
||||
std::vector<size_t> anchor_box_shape_;
|
||||
std::vector<size_t> img_metas_shape_;
|
||||
std::vector<size_t> output_shape_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
CheckValid,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
|
||||
CheckValidCpuKernelMod, float);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
CheckValid,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
|
||||
CheckValidCpuKernelMod, float16);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
CheckValid, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
|
||||
CheckValidCpuKernelMod, int16_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
CheckValid, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
|
||||
CheckValidCpuKernelMod, uint8_t);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -14,15 +14,46 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "plugin/device/cpu/kernel/cholesky_inverse_cpu_kernel.h"
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
constexpr size_t kDimNum = 2;
|
||||
}
|
||||
using CholeskyInverseFunc = std::function<bool(const CNodePtr &, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &)>;
|
||||
template <typename T>
|
||||
void CholeskyInverseCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
bool CholeskyInverseKernelFunc(const CNodePtr &node_wpt, const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
auto input_x0 = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto output_y = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
auto inputShape = AnfAlgo::GetInputDeviceShape(node_wpt, 0);
|
||||
int64_t n = SizeToLong(inputShape[0]);
|
||||
using MatrixXd = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
|
||||
Eigen::Map<MatrixXd> A(input_x0, n, n);
|
||||
MatrixXd result;
|
||||
auto upper = common::AnfAlgo::GetNodeAttr<bool>(node_wpt, "upper");
|
||||
if (upper) {
|
||||
result = (A.transpose() * A).inverse();
|
||||
} else {
|
||||
result = (A * A.transpose()).inverse();
|
||||
}
|
||||
for (int64_t i = 0; i < n; i++) {
|
||||
for (int64_t j = 0; j < n; j++) {
|
||||
*(output_y + i * n + j) = result(i, j);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static std::vector<std::pair<KernelAttr, CholeskyInverseFunc>> kernel_attr_list = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CholeskyInverseKernelFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CholeskyInverseKernelFunc<double>}};
|
||||
} // namespace
|
||||
|
||||
void CholeskyInverseCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
node_wpt_ = kernel_node;
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
|
@ -39,31 +70,24 @@ void CholeskyInverseCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
|||
<< "while row is " << x_shape[x_shape.size() - kDimNum] << ", col is "
|
||||
<< x_shape[x_shape.size() - 1];
|
||||
}
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "Cholesky inverse valid cpu kernel does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
|
||||
kernel_func_ = std::bind(kernel_attr_list[index].second, node_wpt_, std::placeholders::_1, std::placeholders::_2);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool CholeskyInverseCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
auto input_x0 = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto output_y = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
auto inputShape = AnfAlgo::GetInputDeviceShape(node_wpt_, 0);
|
||||
int64_t n = SizeToLong(inputShape[0]);
|
||||
typedef Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> MatrixXd;
|
||||
Eigen::Map<MatrixXd> A(input_x0, n, n);
|
||||
MatrixXd result;
|
||||
auto upper = common::AnfAlgo::GetNodeAttr<bool>(node_wpt_, "upper");
|
||||
if (upper) {
|
||||
result = (A.transpose() * A).inverse();
|
||||
} else {
|
||||
result = (A * A.transpose()).inverse();
|
||||
}
|
||||
for (int64_t i = 0; i < n; i++) {
|
||||
for (int64_t j = 0; j < n; j++) {
|
||||
*(output_y + i * n + j) = result(i, j);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
std::vector<KernelAttr> CholeskyInverseCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
std::transform(kernel_attr_list.begin(), kernel_attr_list.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, CholeskyInverseFunc> &pair) { return pair.first; });
|
||||
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, CholeskyInverse, CholeskyInverseCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,13 +18,12 @@
|
|||
#include <Eigen/Dense>
|
||||
#include <vector>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
constexpr size_t kInputNum = 1;
|
||||
constexpr size_t kOutputNum = 1;
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class CholeskyInverseCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
CholeskyInverseCpuKernelMod() = default;
|
||||
|
@ -32,16 +31,18 @@ class CholeskyInverseCpuKernelMod : public NativeCpuKernelMod {
|
|||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(inputs, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
std::function<bool(const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)> kernel_func_;
|
||||
CNodePtr node_wpt_;
|
||||
};
|
||||
MS_REG_CPU_KERNEL_T(CholeskyInverse, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
CholeskyInverseCpuKernelMod, float)
|
||||
MS_REG_CPU_KERNEL_T(CholeskyInverse, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
CholeskyInverseCpuKernelMod, double)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHOLESKYINVERSE_CPU_KERNEL_H_
|
||||
|
|
|
@ -153,5 +153,26 @@ void CoalesceCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &i
|
|||
y_shape_addr[i] = x_shape_addr[i];
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> CoalesceCpuKernelMod::GetOpSupport() {
|
||||
static std::vector<KernelAttr> support_list = {KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeInt64)};
|
||||
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Coalesce, CoalesceCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -39,6 +39,9 @@ class CoalesceCpuKernelMod : public NativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
@ -50,26 +53,6 @@ class CoalesceCpuKernelMod : public NativeCpuKernelMod {
|
|||
size_t jump = 0;
|
||||
CNodeWeakPtr node_wpt_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(Coalesce,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
CoalesceCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(Coalesce,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
CoalesceCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_COALESCE_CPU_KERNEL_H_
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/concat_cpu_kernel.h"
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -22,9 +24,7 @@ namespace kernel {
|
|||
namespace {
|
||||
constexpr size_t kConcatOutputsNum = 1;
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
void ConcatCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
void ConcatCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
cnode_ptr_ = kernel_node;
|
||||
|
@ -33,12 +33,19 @@ void ConcatCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
|||
if (axis_ < 0) {
|
||||
axis_ = axis_ + SizeToInt(input_1_shape.size());
|
||||
}
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "Concat does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
|
||||
kernel_func_ = func_list_[index].second;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool ConcatCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
bool ConcatCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
auto node_ = cnode_ptr_.lock();
|
||||
if (!node_) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', cnode_ptr_(kernel_node) is expired. Error no: " << node_;
|
||||
|
@ -88,5 +95,39 @@ bool ConcatCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs
|
|||
ParallelLaunchAutoSearch(task, before_axis, this, ¶llel_search_info_);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, ConcatCpuKernelMod::ConcatFunc>> ConcatCpuKernelMod::func_list_ = {
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&ConcatCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
&ConcatCpuKernelMod::LaunchKernel<double>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
&ConcatCpuKernelMod::LaunchKernel<int8_t>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
&ConcatCpuKernelMod::LaunchKernel<int16_t>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
&ConcatCpuKernelMod::LaunchKernel<int32_t>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
&ConcatCpuKernelMod::LaunchKernel<int64_t>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
&ConcatCpuKernelMod::LaunchKernel<uint8_t>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
||||
&ConcatCpuKernelMod::LaunchKernel<uint16_t>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
&ConcatCpuKernelMod::LaunchKernel<uint32_t>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
&ConcatCpuKernelMod::LaunchKernel<uint64_t>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
&ConcatCpuKernelMod::LaunchKernel<bool>}};
|
||||
|
||||
std::vector<KernelAttr> ConcatCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, ConcatFunc> &pair) { return pair.first; });
|
||||
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Concat, ConcatCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,13 +19,13 @@
|
|||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class ConcatCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
ConcatCpuKernelMod() = default;
|
||||
|
@ -34,23 +34,22 @@ class ConcatCpuKernelMod : public NativeCpuKernelMod {
|
|||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
using ConcatFunc = std::function<bool(ConcatCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &)>;
|
||||
static std::vector<std::pair<KernelAttr, ConcatFunc>> func_list_;
|
||||
ConcatFunc kernel_func_;
|
||||
int axis_{0};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCpuKernelMod, float)
|
||||
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCpuKernelMod, double)
|
||||
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCpuKernelMod, int8_t)
|
||||
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCpuKernelMod, int16_t)
|
||||
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCpuKernelMod, int32_t)
|
||||
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCpuKernelMod, int64_t)
|
||||
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCpuKernelMod, uint8_t)
|
||||
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCpuKernelMod, uint16_t)
|
||||
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCpuKernelMod, uint32_t)
|
||||
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCpuKernelMod, uint64_t)
|
||||
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCpuKernelMod, bool)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/concat_offset_cpu_kernel.h"
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -23,9 +25,7 @@ namespace {
|
|||
constexpr size_t kConcatOffsetOutputNum = 1;
|
||||
constexpr size_t kConcatOffsetOutputShapeSize = 2;
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
void ConcatOffsetCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
void ConcatOffsetCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
cnode_ptr_ = kernel_node;
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
|
@ -42,12 +42,18 @@ void ConcatOffsetCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
|||
<< "', the 'axis' should be less than the dimension of 'input_x', but got 'axis': " << axis_
|
||||
<< ", and the dimension of 'input_x': " << input_1_shape.size();
|
||||
}
|
||||
}
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "Concat offset does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
|
||||
kernel_func_ = func_list_[index].second;
|
||||
}
|
||||
template <typename T>
|
||||
bool ConcatOffsetCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
bool ConcatOffsetCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kConcatOffsetOutputNum, kernel_name_);
|
||||
auto node_ = cnode_ptr_.lock();
|
||||
if (!node_) {
|
||||
|
@ -94,5 +100,37 @@ bool ConcatOffsetCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, ConcatOffsetCpuKernelMod::ConcatOffsetFunc>> ConcatOffsetCpuKernelMod::func_list_ = {
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
|
||||
&ConcatOffsetCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64),
|
||||
&ConcatOffsetCpuKernelMod::LaunchKernel<int8_t>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64),
|
||||
&ConcatOffsetCpuKernelMod::LaunchKernel<int16_t>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
|
||||
&ConcatOffsetCpuKernelMod::LaunchKernel<int32_t>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
&ConcatOffsetCpuKernelMod::LaunchKernel<int64_t>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64),
|
||||
&ConcatOffsetCpuKernelMod::LaunchKernel<uint8_t>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64),
|
||||
&ConcatOffsetCpuKernelMod::LaunchKernel<uint16_t>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64),
|
||||
&ConcatOffsetCpuKernelMod::LaunchKernel<uint32_t>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
&ConcatOffsetCpuKernelMod::LaunchKernel<uint64_t>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64),
|
||||
&ConcatOffsetCpuKernelMod::LaunchKernel<bool>}}; // namespace kernel
|
||||
|
||||
std::vector<KernelAttr> ConcatOffsetCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, ConcatOffsetFunc> &pair) { return pair.first; });
|
||||
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, ConcatOffset, ConcatOffsetCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,12 +19,12 @@
|
|||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class ConcatOffsetCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
ConcatOffsetCpuKernelMod() = default;
|
||||
|
@ -33,22 +33,22 @@ class ConcatOffsetCpuKernelMod : public NativeCpuKernelMod {
|
|||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
using ConcatOffsetFunc = std::function<bool(ConcatOffsetCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &)>;
|
||||
static std::vector<std::pair<KernelAttr, ConcatOffsetFunc>> func_list_;
|
||||
ConcatOffsetFunc kernel_func_;
|
||||
size_t axis_{0};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCpuKernelMod, int8_t)
|
||||
MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCpuKernelMod, int16_t)
|
||||
MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCpuKernelMod, int32_t)
|
||||
MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCpuKernelMod, int64_t)
|
||||
MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCpuKernelMod, uint8_t)
|
||||
MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCpuKernelMod, uint16_t)
|
||||
MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCpuKernelMod, uint32_t)
|
||||
MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCpuKernelMod, uint64_t)
|
||||
MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCpuKernelMod, bool)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -19,6 +19,8 @@
|
|||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <cmath>
|
||||
#include <map>
|
||||
#include <set>
|
||||
|
||||
#include "utils/profile.h"
|
||||
#include "runtime/graph_scheduler/actor/actor_common.h"
|
||||
|
@ -82,6 +84,125 @@ void NativeCpuKernelMod::Init(const CNodePtr &kernel_node) {
|
|||
InitInputOutputSize(kernel_node);
|
||||
}
|
||||
|
||||
std::vector<TypeId> NativeCpuKernelMod::GetInputDtypes(const CNodePtr &kernel_node) {
|
||||
std::vector<TypeId> input_types;
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
auto dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index);
|
||||
input_types.emplace_back(dtype);
|
||||
}
|
||||
return input_types;
|
||||
}
|
||||
|
||||
std::vector<std::string> NativeCpuKernelMod::GetInputFormats(const CNodePtr &kernel_node) {
|
||||
std::vector<std::string> input_formats;
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
input_formats.emplace_back(kOpFormat_DEFAULT);
|
||||
}
|
||||
return input_formats;
|
||||
}
|
||||
|
||||
std::vector<TypeId> NativeCpuKernelMod::GetOutputDtypes(const CNodePtr &kernel_node) {
|
||||
std::vector<TypeId> output_types;
|
||||
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
auto dtype = common::AnfAlgo::GetOutputInferDataType(kernel_node, output_index);
|
||||
output_types.emplace_back(dtype);
|
||||
}
|
||||
return output_types;
|
||||
}
|
||||
|
||||
std::vector<std::string> NativeCpuKernelMod::GetOutputFormats(const CNodePtr &kernel_node) {
|
||||
std::vector<std::string> output_formats;
|
||||
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
output_formats.emplace_back(kOpFormat_DEFAULT);
|
||||
}
|
||||
return output_formats;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> NativeCpuKernelMod::GetAllSupportedList(const std::string &kernel_name) {
|
||||
if (initialize_.count(kernel_name) == 0) {
|
||||
std::vector<KernelAttr> kernel_attrs;
|
||||
auto kernel_support = GetOpSupport();
|
||||
(void)kernel_attrs.insert(kernel_attrs.end(), kernel_support.begin(), kernel_support.end());
|
||||
if (kernel_attrs.empty()) {
|
||||
auto oplib_support = GetSupportFromOpLib(kernel_name);
|
||||
(void)kernel_attrs.insert(kernel_attrs.end(), oplib_support.begin(), oplib_support.end());
|
||||
}
|
||||
(void)support_map_.emplace(kernel_name, kernel_attrs);
|
||||
initialize_.insert(kernel_name);
|
||||
}
|
||||
|
||||
return support_map_[kernel_name];
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> NativeCpuKernelMod::GetSupportFromOpLib(const std::string &kernel_name) {
|
||||
static std::set<std::string> same_op_name = {"Concat", "Pack", "Stack", "Split", "Transpose",
|
||||
"Unpack", "AddN", "ConcatOffset", "DynamicStitch"};
|
||||
auto op_info = mindspore::kernel::OpLib::FindOp(kernel_name, kernel::OpImplyType::kCPU);
|
||||
if (op_info == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Not find op[" << kernel_name << "] in cpu. For more details, "
|
||||
<< "please refer to the list of supported cpu operations at https://www.mindspore.cn.";
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> support_kernel_attrs;
|
||||
auto inputs_ptr = op_info->inputs_ptr();
|
||||
auto outputs_ptr = op_info->outputs_ptr();
|
||||
if (outputs_ptr.empty()) {
|
||||
MS_LOG(EXCEPTION) << "The output dimension of operator '" << kernel_name << "' should not be zero.";
|
||||
}
|
||||
|
||||
auto support_size = outputs_ptr[0]->dtypes().size();
|
||||
for (size_t i = 0; i < support_size; i++) {
|
||||
KernelAttr kernel_attr;
|
||||
for (size_t j = 0; j < inputs_ptr.size(); j++) {
|
||||
auto input_dtypes = inputs_ptr[j]->dtypes();
|
||||
auto input_formats = inputs_ptr[j]->formats();
|
||||
(void)kernel_attr.AddInputAttr(kernel::DtypeToTypeId(input_dtypes[i]), input_formats[i]);
|
||||
}
|
||||
for (size_t j = 0; j < outputs_ptr.size(); j++) {
|
||||
auto output_dtypes = outputs_ptr[j]->dtypes();
|
||||
auto output_formats = outputs_ptr[j]->formats();
|
||||
(void)kernel_attr.AddOutputAttr(kernel::DtypeToTypeId(output_dtypes[i]), output_formats[i]);
|
||||
}
|
||||
if (same_op_name.count(op_info->op_name()) != 0) {
|
||||
(void)kernel_attr.AddAllSameAttr(true);
|
||||
}
|
||||
support_kernel_attrs.push_back(kernel_attr);
|
||||
}
|
||||
|
||||
return support_kernel_attrs;
|
||||
}
|
||||
|
||||
std::map<std::string, std::vector<KernelAttr>> NativeCpuKernelMod::support_map_{};
|
||||
std::set<std::string> NativeCpuKernelMod::initialize_{};
|
||||
|
||||
void NativeCpuKernelMod::SetCpuRefMapToKernelInfo(const CNodePtr &apply_kernel) {
|
||||
auto kernel_attrs = GetOpSupport();
|
||||
if (kernel_attrs.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromNode(apply_kernel);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, kernel_attrs);
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << common::AnfAlgo::GetCNodeName(apply_kernel)
|
||||
<< " does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(apply_kernel);
|
||||
auto kernel_info = dynamic_cast<device::KernelInfo *>(apply_kernel->kernel_info());
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
const KernelBuildInfo *kernel_build_Info = kernel_info->select_kernel_build_info();
|
||||
MS_EXCEPTION_IF_NULL(kernel_build_Info);
|
||||
const auto &matched_kernel_attr = kernel_attrs[index];
|
||||
if (!matched_kernel_attr.GetOutInRefMap().empty()) {
|
||||
kernel_info->set_ref_map(matched_kernel_attr.GetOutInRefMap());
|
||||
}
|
||||
}
|
||||
|
||||
void CPUKernelUtils::ExpandDimsTo4(std::vector<size_t> *shape) {
|
||||
MS_EXCEPTION_IF_NULL(shape);
|
||||
auto len = shape->size();
|
||||
|
@ -435,5 +556,15 @@ void AxisIterator::Init(const std::vector<size_t> &input_shape, size_t axis) {
|
|||
inner_size_ *= input_shape[i];
|
||||
}
|
||||
}
|
||||
|
||||
int Sign(float x) {
|
||||
if (x > 0) {
|
||||
return 1;
|
||||
}
|
||||
if (x < 0) {
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,8 +23,11 @@
|
|||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <set>
|
||||
|
||||
#include "kernel/kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_mod.h"
|
||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
|
@ -145,8 +148,21 @@ class NativeCpuKernelMod : public CpuKernelMod {
|
|||
|
||||
ParallelSearchInfo parallel_search_info_;
|
||||
|
||||
static std::vector<KernelAttr> GetCpuSupportedList(const std::string &kernel_name) {
|
||||
auto temp_mod = kernel::Factory<NativeCpuKernelMod>::Instance().Create(kernel_name);
|
||||
if (temp_mod == nullptr) {
|
||||
MS_LOG(ERROR) << "Get cpu supported list failed!";
|
||||
return std::vector<KernelAttr>{};
|
||||
}
|
||||
return temp_mod->GetAllSupportedList(kernel_name);
|
||||
}
|
||||
|
||||
void SetCpuRefMapToKernelInfo(const CNodePtr &apply_kernel);
|
||||
|
||||
protected:
|
||||
virtual void InitInputOutputSize(const CNodePtr &kernel_node);
|
||||
virtual std::vector<KernelAttr> GetOpSupport() { return {}; }
|
||||
|
||||
CNodeWeakPtr cnode_ptr_;
|
||||
|
||||
template <typename T>
|
||||
|
@ -162,6 +178,28 @@ class NativeCpuKernelMod : public CpuKernelMod {
|
|||
|
||||
return reinterpret_cast<T *>(addr_list[index]->addr);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<KernelAttr> GetAllSupportedList(const std::string &kernel_name);
|
||||
std::vector<KernelAttr> GetSupportFromOpLib(const std::string &kernel_name);
|
||||
std::vector<TypeId> GetInputDtypes(const CNodePtr &kernel_node);
|
||||
std::vector<std::string> GetInputFormats(const CNodePtr &kernel_node);
|
||||
std::vector<TypeId> GetOutputDtypes(const CNodePtr &kernel_node);
|
||||
std::vector<std::string> GetOutputFormats(const CNodePtr &kernel_node);
|
||||
static std::map<std::string, std::vector<KernelAttr>> support_map_;
|
||||
static std::set<std::string> initialize_;
|
||||
};
|
||||
|
||||
class CpuKernelFunc {
|
||||
public:
|
||||
CpuKernelFunc() = default;
|
||||
virtual ~CpuKernelFunc() = default;
|
||||
virtual bool RunFunc(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) = 0;
|
||||
virtual void InitFunc(const CNodePtr &kernel_node) {}
|
||||
virtual void InitInputOutputSize(const CNodePtr &kernel_node, std::vector<size_t> *input_size_list,
|
||||
std::vector<size_t> *output_size_list, std::vector<size_t> *workspace_size_list) {}
|
||||
ParallelSearchInfo parallel_search_info_;
|
||||
};
|
||||
|
||||
class CPUKernelUtils {
|
||||
|
@ -248,6 +286,8 @@ class AxisIterator {
|
|||
size_t inner_size_{0};
|
||||
size_t axis_offset_{0};
|
||||
};
|
||||
|
||||
int Sign(float x);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -1,195 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
#include "runtime/device/kernel_info.h"
|
||||
#include "plugin/device/cpu/hal/device/kernel_select_cpu.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
const std::set<std::string> same_op_name = {"Concat", "Pack", "Stack", "Split", "Transpose",
|
||||
"Unpack", "AddN", "ConcatOffset", "DynamicStitch"};
|
||||
}
|
||||
NativeCpuKernelModFactory &NativeCpuKernelModFactory::GetInstance() {
|
||||
static NativeCpuKernelModFactory instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void NativeCpuKernelModFactory::Register(const std::string &kernel_name, const KernelAttr &kernel_attr,
|
||||
NativeCpuKernelModCreator &&kernel_creator) {
|
||||
(void)name_to_attr_creator_[kernel_name].emplace_back(kernel_attr, kernel_creator);
|
||||
}
|
||||
|
||||
std::shared_ptr<NativeCpuKernelMod> NativeCpuKernelModFactory::Create(const std::string &kernel_name,
|
||||
const CNodePtr &apply_kernel) {
|
||||
MS_EXCEPTION_IF_NULL(apply_kernel);
|
||||
auto kernel_info = dynamic_cast<device::KernelInfo *>(apply_kernel->kernel_info());
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
const KernelBuildInfo *kernel_build_Info = kernel_info->select_kernel_build_info();
|
||||
MS_EXCEPTION_IF_NULL(kernel_build_Info);
|
||||
std::pair<bool, size_t> ret_pair = CPUKernelAttrCheck(kernel_name, *kernel_build_Info);
|
||||
if (ret_pair.first) {
|
||||
SetRefMapToKernelInfo(kernel_name, ret_pair.second, kernel_info);
|
||||
return (name_to_attr_creator_.find(kernel_name)->second)[ret_pair.second].second();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void NativeCpuKernelModFactory::SetKernelAttrs(const std::shared_ptr<kernel::OpInfo> op_info,
|
||||
std::vector<KernelAttr> *kernel_attrs) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_attrs);
|
||||
MS_EXCEPTION_IF_NULL(op_info);
|
||||
auto inputs_ptr = op_info->inputs_ptr();
|
||||
auto outputs_ptr = op_info->outputs_ptr();
|
||||
if (outputs_ptr.empty()) {
|
||||
MS_LOG(EXCEPTION) << "The output dimension of operator '" << op_info->op_name() << "' should not be zero.";
|
||||
}
|
||||
auto first_output_dtypes = outputs_ptr[0]->dtypes();
|
||||
|
||||
for (size_t i = 0; i < first_output_dtypes.size(); i++) {
|
||||
KernelAttr kernel_attr;
|
||||
for (size_t j = 0; j < inputs_ptr.size(); j++) {
|
||||
auto input_dtypes = inputs_ptr[j]->dtypes();
|
||||
auto input_formats = inputs_ptr[j]->formats();
|
||||
(void)kernel_attr.AddInputAttr(kernel::DtypeToTypeId(input_dtypes[i]), input_formats[i]);
|
||||
}
|
||||
for (size_t j = 0; j < outputs_ptr.size(); j++) {
|
||||
auto output_dtypes = outputs_ptr[j]->dtypes();
|
||||
auto output_formats = outputs_ptr[j]->formats();
|
||||
(void)kernel_attr.AddOutputAttr(kernel::DtypeToTypeId(output_dtypes[i]), output_formats[i]);
|
||||
}
|
||||
if (same_op_name.count(op_info->op_name()) != 0) {
|
||||
(void)kernel_attr.SetAllSameAttr(true);
|
||||
}
|
||||
(void)kernel_attrs->emplace_back(kernel_attr);
|
||||
}
|
||||
}
|
||||
|
||||
void NativeCpuKernelModFactory::UpdateKernelAttrs(const std::string &kernel_name,
|
||||
const std::vector<KernelAttr> &kernel_attrs) {
|
||||
size_t attr_size = kernel_attrs.size();
|
||||
std::vector<std::pair<KernelAttr, NativeCpuKernelModCreator>> attr_creators(attr_size);
|
||||
auto iter = name_to_attr_creator_.find(kernel_name);
|
||||
if (iter == name_to_attr_creator_.end()) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name
|
||||
<< ", only support these types: Concat, Pack, Stack, Split, Transpose, Unpack, AddN, "
|
||||
"ConcatOffset or DynamicStitch currently, but got "
|
||||
<< kernel_name;
|
||||
}
|
||||
|
||||
if (attr_size <= iter->second.size()) {
|
||||
for (size_t i = 0; i < attr_size; i++) {
|
||||
auto creator = name_to_attr_creator_.find(kernel_name)->second[i].second;
|
||||
attr_creators[i] = std::make_pair(kernel_attrs[i], creator);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(INFO) << "attr size is not equal creators size " << kernel_name << " attr_size = " << attr_size
|
||||
<< " creator_size = " << iter->second.size();
|
||||
auto single_creator = name_to_attr_creator_.find(kernel_name)->second[0].second;
|
||||
for (size_t i = 0; i < attr_size; i++) {
|
||||
attr_creators[i] = std::make_pair(kernel_attrs[i], single_creator);
|
||||
}
|
||||
}
|
||||
name_to_attr_creator_[kernel_name] = attr_creators;
|
||||
}
|
||||
|
||||
std::pair<bool, size_t> NativeCpuKernelModFactory::CPUKernelAttrCheck(const std::string &kernel_name,
|
||||
const KernelBuildInfo &kernel_info) {
|
||||
auto iter = name_to_attr_creator_.find(kernel_name);
|
||||
if (iter == name_to_attr_creator_.end()) {
|
||||
MS_LOG(INFO) << "Not registered CPU kernel: op[" << kernel_name << "]!";
|
||||
return std::make_pair(false, 0);
|
||||
}
|
||||
|
||||
if (device::cpu::IsDynamicParamKernel(kernel_name)) {
|
||||
return std::make_pair(true, 0);
|
||||
}
|
||||
|
||||
auto kernel_attrs = GetSupportedKernelAttrList(kernel_name);
|
||||
if (kernel_attrs[0].GetInputSize() == 0 && kernel_attrs[0].GetOutputSize() == 0) {
|
||||
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(kernel_name, kernel::OpImplyType::kCPU);
|
||||
if (op_info_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Not find op[" << kernel_name << "] in cpu. For more details, "
|
||||
<< "please refer to the list of supported cpu operations at https://www.mindspore.cn.";
|
||||
}
|
||||
kernel_attrs.clear();
|
||||
SetKernelAttrs(op_info_ptr, &kernel_attrs);
|
||||
kernel::NativeCpuKernelModFactory::GetInstance().UpdateKernelAttrs(kernel_name, kernel_attrs);
|
||||
}
|
||||
for (size_t index = 0; index < kernel_attrs.size(); ++index) {
|
||||
if (CPUKernelSingleAttrCheck(kernel_attrs[index], kernel_info)) {
|
||||
return std::make_pair(true, index);
|
||||
}
|
||||
}
|
||||
return std::make_pair(false, 0);
|
||||
}
|
||||
|
||||
bool NativeCpuKernelModFactory::CPUKernelSingleAttrCheck(const KernelAttr &kernel_attr,
|
||||
const KernelBuildInfo &kernel_info) const {
|
||||
for (size_t i = 0; i < kernel_info.GetInputNum(); ++i) {
|
||||
auto dtype = kernel_attr.GetAllSame() ? kernel_attr.GetInputAttr(0).first : kernel_attr.GetInputAttr(i).first;
|
||||
if (kernel_info.GetInputDeviceType(i) != dtype) {
|
||||
MS_LOG(DEBUG) << "input index:" << i << ", kernel info type:" << kernel_info.GetInputDeviceType(i)
|
||||
<< ", register type:" << dtype;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < kernel_info.GetOutputNum(); ++i) {
|
||||
auto dtype = kernel_attr.GetAllSame() ? kernel_attr.GetOutputAttr(0).first : kernel_attr.GetOutputAttr(i).first;
|
||||
if (kernel_info.GetOutputDeviceType(i) != dtype) {
|
||||
MS_LOG(DEBUG) << "output index:" << i << ", kernel info type:" << kernel_info.GetOutputDeviceType(i)
|
||||
<< ", register type:" << dtype;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void NativeCpuKernelModFactory::SetRefMapToKernelInfo(const std::string &kernel_name, size_t index,
|
||||
device::KernelInfo *kernel_info) {
|
||||
const auto &kernel_attr = (name_to_attr_creator_.find(kernel_name)->second)[index].first;
|
||||
if (!kernel_attr.GetOutInRefMap().empty()) {
|
||||
kernel_info->set_ref_map(kernel_attr.GetOutInRefMap());
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> NativeCpuKernelModFactory::GetSupportedKernelAttrList(const std::string &kernel_name) {
|
||||
std::vector<KernelAttr> result;
|
||||
auto iter = name_to_attr_creator_.find(kernel_name);
|
||||
if (iter == name_to_attr_creator_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Not register CPU kernel of operator: " << kernel_name;
|
||||
}
|
||||
auto creators = iter->second;
|
||||
result.reserve(creators.size());
|
||||
for (size_t index = 0; index < creators.size(); ++index) {
|
||||
auto attr_creator = creators[index];
|
||||
(void)result.emplace_back(attr_creator.first);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool NativeCpuKernelModFactory::SearchRegisteredOp(const std::string &kernel_name) const {
|
||||
auto iter = name_to_attr_creator_.find(kernel_name);
|
||||
return iter != name_to_attr_creator_.end();
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -1,95 +0,0 @@
|
|||
/**
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CPU_KERNEL_FACTORY_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CPU_KERNEL_FACTORY_H_
|
||||
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "kernel/oplib/oplib.h"
|
||||
#include "plugin/device/cpu/hal/device/kernel_select_cpu.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
using mindspore::device::cpu::KernelAttr;
|
||||
using NativeCpuKernelModCreator = std::function<std::shared_ptr<NativeCpuKernelMod>()>;
|
||||
|
||||
class NativeCpuKernelModFactory {
|
||||
public:
|
||||
static NativeCpuKernelModFactory &GetInstance();
|
||||
void Register(const std::string &kernel_name, const KernelAttr &kernel_attr,
|
||||
NativeCpuKernelModCreator &&kernel_creator);
|
||||
std::shared_ptr<NativeCpuKernelMod> Create(const std::string &kernel_name, const CNodePtr &apply_kernel);
|
||||
void SetKernelAttrs(const std::shared_ptr<kernel::OpInfo> op_info, std::vector<KernelAttr> *kernel_attrs);
|
||||
void UpdateKernelAttrs(const std::string &kernel_name, const std::vector<KernelAttr> &kernel_attrs);
|
||||
std::vector<KernelAttr> GetSupportedKernelAttrList(const std::string &kernel_name);
|
||||
bool SearchRegisteredOp(const std::string &kernel_name) const;
|
||||
|
||||
private:
|
||||
NativeCpuKernelModFactory() = default;
|
||||
~NativeCpuKernelModFactory() = default;
|
||||
DISABLE_COPY_AND_ASSIGN(NativeCpuKernelModFactory)
|
||||
std::pair<bool, size_t> CPUKernelAttrCheck(const std::string &kernel_name, const KernelBuildInfo &kernel_info);
|
||||
bool CPUKernelSingleAttrCheck(const KernelAttr &kernel_attr, const KernelBuildInfo &kernel_info) const;
|
||||
|
||||
// Set output and input ref map to kernel info which will be used by graph compiler.
|
||||
void SetRefMapToKernelInfo(const std::string &kernel_name, size_t index, device::KernelInfo *kernel_info);
|
||||
|
||||
std::map<std::string, std::vector<std::pair<KernelAttr, NativeCpuKernelModCreator>>> name_to_attr_creator_;
|
||||
};
|
||||
|
||||
class NativeCpuKernelRegistrar {
|
||||
public:
|
||||
NativeCpuKernelRegistrar(const std::string &kernel_name, const KernelAttr &kernel_attr,
|
||||
NativeCpuKernelModCreator &&kernel_creator) {
|
||||
NativeCpuKernelModFactory::GetInstance().Register(kernel_name, kernel_attr, std::move(kernel_creator));
|
||||
}
|
||||
~NativeCpuKernelRegistrar() = default;
|
||||
};
|
||||
|
||||
#define MS_REG_CPU_KERNEL(OPNAME, ATTR, OPCLASS) MS_REG_CPU_KERNEL_(__COUNTER__, OPNAME, ATTR, OPCLASS)
|
||||
#define MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS) _MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS)
|
||||
#define _MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS) \
|
||||
static_assert(std::is_base_of<NativeCpuKernelMod, OPCLASS>::value, " must be base of NativeCpuKernelMod"); \
|
||||
static const NativeCpuKernelRegistrar g_cpu_kernel_##COUNT##_reg(#OPNAME, ATTR, \
|
||||
[]() { return std::make_shared<OPCLASS>(); });
|
||||
|
||||
#define MS_REG_CPU_KERNEL_T(OPNAME, ATTR, OPCLASS, T) MS_REG_CPU_KERNEL_T_(__COUNTER__, OPNAME, ATTR, OPCLASS, T)
|
||||
#define MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) _MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T)
|
||||
#define _MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) \
|
||||
static_assert(std::is_base_of<NativeCpuKernelMod, OPCLASS<T>>::value, " must be base of NativeCpuKernelMod"); \
|
||||
static const NativeCpuKernelRegistrar g_cpu_kernel_##COUNT##_##OPNAME##_##T##_reg( \
|
||||
#OPNAME, ATTR, []() { return std::make_shared<OPCLASS<T>>(); });
|
||||
|
||||
#define MS_REG_CPU_KERNEL_T_S(OPNAME, ATTR, OPCLASS, T, S) \
|
||||
MS_REG_CPU_KERNEL_T_S_(__COUNTER__, OPNAME, ATTR, OPCLASS, T, S)
|
||||
#define MS_REG_CPU_KERNEL_T_S_(COUNT, OPNAME, ATTR, OPCLASS, T, S) \
|
||||
_MS_REG_CPU_KERNEL_T_S_(COUNT, OPNAME, ATTR, OPCLASS, T, S)
|
||||
#define _MS_REG_CPU_KERNEL_T_S_(COUNT, OPNAME, ATTR, OPCLASS, T, S) \
|
||||
static_assert(std::is_base_of<NativeCpuKernelMod, OPCLASS<T, S>>::value, " must be base of NativeCpuKernelMod"); \
|
||||
static const NativeCpuKernelRegistrar g_cpu_kernel_##COUNT##_##OPNAME##_##T##_##S##_reg( \
|
||||
#OPNAME, ATTR, []() { return std::make_shared<OPCLASS<T, S>>(); });
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CPU_KERNEL_FACTORY_H_
|
|
@ -19,8 +19,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
void CropAndResizeCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
void CropAndResizeCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
|
@ -98,12 +97,13 @@ void CropAndResizeCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
|||
method_ = BILINEAR_V2;
|
||||
}
|
||||
extrapolation_value_ = common::AnfAlgo::GetNodeAttr<float>(kernel_node, "extrapolation_value");
|
||||
|
||||
InitFunc(kernel_node);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool CropAndResizeCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
bool CropAndResizeCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
auto *input_image = reinterpret_cast<T *>(inputs[IMAGE]->addr);
|
||||
auto *input_boxes = reinterpret_cast<float *>(inputs[BOXES]->addr);
|
||||
auto *input_box_index = reinterpret_cast<int *>(inputs[BOX_INDEX]->addr);
|
||||
|
@ -218,5 +218,153 @@ bool CropAndResizeCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr>
|
|||
ParallelLaunchAutoSearch(task, IntToSize(output_size_), this, ¶llel_search_info_);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, CropAndResizeCpuKernelMod::CropAndResizeFunc>> CropAndResizeCpuKernelMod::func_list_ =
|
||||
{{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&CropAndResizeCpuKernelMod::LaunchKernel<float16>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&CropAndResizeCpuKernelMod::LaunchKernel<float16>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&CropAndResizeCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&CropAndResizeCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&CropAndResizeCpuKernelMod::LaunchKernel<double>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&CropAndResizeCpuKernelMod::LaunchKernel<double>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&CropAndResizeCpuKernelMod::LaunchKernel<int8_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&CropAndResizeCpuKernelMod::LaunchKernel<int8_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&CropAndResizeCpuKernelMod::LaunchKernel<int16_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&CropAndResizeCpuKernelMod::LaunchKernel<int16_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&CropAndResizeCpuKernelMod::LaunchKernel<int8_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&CropAndResizeCpuKernelMod::LaunchKernel<int32_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&CropAndResizeCpuKernelMod::LaunchKernel<int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&CropAndResizeCpuKernelMod::LaunchKernel<int64_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&CropAndResizeCpuKernelMod::LaunchKernel<uint8_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&CropAndResizeCpuKernelMod::LaunchKernel<uint8_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&CropAndResizeCpuKernelMod::LaunchKernel<uint16_t>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&CropAndResizeCpuKernelMod::LaunchKernel<uint16_t>}};
|
||||
|
||||
void CropAndResizeCpuKernelMod::InitFunc(const CNodePtr &kernel_node) {
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "Concat does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
|
||||
kernel_func_ = func_list_[index].second;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> CropAndResizeCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, CropAndResizeFunc> &pair) { return pair.first; });
|
||||
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, CropAndResize, CropAndResizeCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,9 +20,10 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -40,7 +41,6 @@ constexpr size_t BOX_INDEX = 2;
|
|||
constexpr size_t CROP_SIZE = 3;
|
||||
constexpr size_t IMAGE_HEIGHT = 1;
|
||||
constexpr size_t IMAGE_WEIGHT = 2;
|
||||
template <typename T>
|
||||
class CropAndResizeCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
CropAndResizeCpuKernelMod() = default;
|
||||
|
@ -49,9 +49,22 @@ class CropAndResizeCpuKernelMod : public NativeCpuKernelMod {
|
|||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
void InitFunc(const CNodePtr &kernel_node);
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
|
||||
using CropAndResizeFunc = std::function<bool(CropAndResizeCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &)>;
|
||||
static std::vector<std::pair<KernelAttr, CropAndResizeFunc>> func_list_;
|
||||
CropAndResizeFunc kernel_func_;
|
||||
int method_{1};
|
||||
float extrapolation_value_{0.0};
|
||||
int output_size_{0};
|
||||
|
@ -61,168 +74,6 @@ class CropAndResizeCpuKernelMod : public NativeCpuKernelMod {
|
|||
int final_width_{0};
|
||||
int channel_{0};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(CropAndResize,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CropAndResizeCpuKernelMod, float16);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(CropAndResize,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CropAndResizeCpuKernelMod, float16);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(CropAndResize,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CropAndResizeCpuKernelMod, float);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(CropAndResize,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CropAndResizeCpuKernelMod, float);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(CropAndResize,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CropAndResizeCpuKernelMod, double);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(CropAndResize,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CropAndResizeCpuKernelMod, double);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(CropAndResize,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CropAndResizeCpuKernelMod, int8_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(CropAndResize,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CropAndResizeCpuKernelMod, int8_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(CropAndResize,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CropAndResizeCpuKernelMod, int16_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(CropAndResize,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CropAndResizeCpuKernelMod, int16_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(CropAndResize,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CropAndResizeCpuKernelMod, int8_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(CropAndResize,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CropAndResizeCpuKernelMod, int32_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(CropAndResize,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CropAndResizeCpuKernelMod, int64_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(CropAndResize,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CropAndResizeCpuKernelMod, int64_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(CropAndResize,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CropAndResizeCpuKernelMod, uint8_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(CropAndResize,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CropAndResizeCpuKernelMod, uint8_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(CropAndResize,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CropAndResizeCpuKernelMod, uint16_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(CropAndResize,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CropAndResizeCpuKernelMod, uint16_t);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -170,5 +170,44 @@ bool CrossCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inpu
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> CrossCpuKernelMod::GetOpSupport() {
|
||||
static std::vector<KernelAttr> support_list = {
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
|
||||
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
|
||||
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
||||
|
||||
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
|
||||
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex64)
|
||||
.AddInputAttr(kNumberTypeComplex64)
|
||||
.AddOutputAttr(kNumberTypeComplex64),
|
||||
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128)};
|
||||
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Cross, CrossCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
#include <string>
|
||||
#include <complex>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -36,6 +36,9 @@ class CrossCpuKernelMod : public NativeCpuKernelMod {
|
|||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
std::vector<size_t> input1_shape_;
|
||||
std::vector<size_t> input2_shape_;
|
||||
|
@ -43,67 +46,6 @@ class CrossCpuKernelMod : public NativeCpuKernelMod {
|
|||
int64_t dim_;
|
||||
TypeId input1_dtype_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Cross, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
CrossCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Cross, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
CrossCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Cross, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
CrossCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Cross, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
CrossCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Cross, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
CrossCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Cross, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
||||
CrossCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Cross, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
CrossCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Cross, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
CrossCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Cross,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
CrossCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Cross,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
CrossCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Cross,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
CrossCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(Cross,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex64)
|
||||
.AddInputAttr(kNumberTypeComplex64)
|
||||
.AddOutputAttr(kNumberTypeComplex64),
|
||||
CrossCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(Cross,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128),
|
||||
CrossCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -347,5 +347,26 @@ void CTCLossCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> CTCLossCpuKernelMod::GetOpSupport() {
|
||||
static std::vector<KernelAttr> support_list = {KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)};
|
||||
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, CTCLoss, CTCLossCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
#include <limits>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -38,6 +38,9 @@ class CTCLossCpuKernelMod : public NativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
void GenLabelWithBlank(const uint32_t *seq_len, const std::vector<std::vector<uint32_t>> &batch_label,
|
||||
std::vector<std::vector<uint32_t>> *label_with_blank) const;
|
||||
|
@ -68,26 +71,6 @@ class CTCLossCpuKernelMod : public NativeCpuKernelMod {
|
|||
bool ctc_merge_repeated_{false};
|
||||
bool ignore_longer_outputs_than_inputs_{false};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(CTCLoss,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat16)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat16)
|
||||
.AddOutputAttr(kNumberTypeFloat16),
|
||||
CTCLossCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(CTCLoss,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
CTCLossCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CTCLOSS_CPU_KERNEL_H_
|
||||
|
|
|
@ -15,22 +15,30 @@
|
|||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/cummax_cpu_kernel.h"
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
void CummaxCPUKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
void CummaxCPUKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
output1_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
|
||||
output2_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 1);
|
||||
dim_ = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "dim");
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "Cummax does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
|
||||
kernel_func_ = func_list_[index].second;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool CummaxCPUKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
bool CummaxCPUKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
auto input_data_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto output1_data_addr = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
auto output2_data_addr = reinterpret_cast<int64_t *>(outputs[1]->addr);
|
||||
|
@ -98,5 +106,31 @@ bool CummaxCPUKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, CummaxCPUKernelMod::CummaxFunc>> CummaxCPUKernelMod::func_list_ = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
|
||||
&CummaxCPUKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64),
|
||||
&CummaxCPUKernelMod::LaunchKernel<float16>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
|
||||
&CummaxCPUKernelMod::LaunchKernel<int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
&CummaxCPUKernelMod::LaunchKernel<int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64),
|
||||
&CummaxCPUKernelMod::LaunchKernel<int8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64),
|
||||
&CummaxCPUKernelMod::LaunchKernel<uint8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64),
|
||||
&CummaxCPUKernelMod::LaunchKernel<uint32_t>}};
|
||||
|
||||
std::vector<KernelAttr> CummaxCPUKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, CummaxFunc> &pair) { return pair.first; });
|
||||
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Cummax, CummaxCPUKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,12 +19,12 @@
|
|||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class CummaxCPUKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
CummaxCPUKernelMod() = default;
|
||||
|
@ -32,38 +32,25 @@ class CummaxCPUKernelMod : public NativeCpuKernelMod {
|
|||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
using CummaxFunc = std::function<bool(CummaxCPUKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &)>;
|
||||
static std::vector<std::pair<KernelAttr, CummaxFunc>> func_list_;
|
||||
CummaxFunc kernel_func_;
|
||||
std::vector<size_t> input_shape_;
|
||||
std::vector<size_t> output1_shape_;
|
||||
std::vector<size_t> output2_shape_;
|
||||
size_t dim_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Cummax,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
|
||||
CummaxCPUKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Cummax,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64),
|
||||
CummaxCPUKernelMod, float16);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Cummax, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
|
||||
CummaxCPUKernelMod, int32_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Cummax, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
CummaxCPUKernelMod, int64_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Cummax, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64),
|
||||
CummaxCPUKernelMod, int8_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Cummax, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64),
|
||||
CummaxCPUKernelMod, uint8_t);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Cummax, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64),
|
||||
CummaxCPUKernelMod, uint32_t);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMMAX_CPU_KERNEL_H_
|
||||
|
|
|
@ -240,5 +240,17 @@ void CumSumCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inp
|
|||
};
|
||||
ParallelLaunchAutoSearch(task, lens, this, ¶llel_search_info_);
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> CumSumCpuKernelMod::GetOpSupport() {
|
||||
static std::vector<KernelAttr> support_list = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16)},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32)},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8)},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8)}};
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, CumSum, CumSumCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -35,6 +35,9 @@ class CumSumCpuKernelMod : public NativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
void Reshape();
|
||||
|
||||
|
@ -81,17 +84,6 @@ class CumSumCpuKernelMod : public NativeCpuKernelMod {
|
|||
int axis_{0};
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(CumSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
CumSumCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(CumSum, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
CumSumCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(CumSum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
CumSumCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(CumSum, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
CumSumCpuKernelMod);
|
||||
MS_REG_CPU_KERNEL(CumSum, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
CumSumCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMSUM_CPU_KERNEL_H_
|
||||
|
|
|
@ -25,7 +25,6 @@
|
|||
#include <functional>
|
||||
#include "abstract/utils.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_common.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "utils/file_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
#include <functional>
|
||||
#include "abstract/utils.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_common.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/device/cpu/kernel/custom/julia_api.h"
|
||||
#include "utils/file_utils.h"
|
||||
|
||||
|
|
|
@ -44,5 +44,13 @@ bool DebugCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, co
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> DebugCpuKernelMod::GetOpSupport() {
|
||||
static std::vector<KernelAttr> support_list = {
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32)};
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Debug, DebugCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,9 +19,10 @@
|
|||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -34,10 +35,10 @@ class DebugCpuKernelMod : public NativeCpuKernelMod {
|
|||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(Debug, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
|
||||
DebugCpuKernelMod);
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/depthtospace_cpu_kernel.h"
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -23,19 +25,25 @@ constexpr size_t kDepthToSpaceInputsNum = 1;
|
|||
constexpr size_t kDepthToSpaceOutputsNum = 1;
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
void DepthToSpaceCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
void DepthToSpaceCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
|
||||
block_size_ = LongToSize(common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "block_size"));
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "DepthToSpace does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
|
||||
kernel_func_ = func_list_[index].second;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool DepthToSpaceCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> & /* workspace */,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
bool DepthToSpaceCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kDepthToSpaceInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kDepthToSpaceOutputsNum, kernel_name_);
|
||||
auto input_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
|
@ -76,5 +84,37 @@ bool DepthToSpaceCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &
|
|||
ParallelLaunchAutoSearch(task, size, this, ¶llel_search_info_);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, DepthToSpaceCpuKernelMod::DepthToSpaceFunc>> DepthToSpaceCpuKernelMod::func_list_ = {
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&DepthToSpaceCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
&DepthToSpaceCpuKernelMod::LaunchKernel<float16>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
&DepthToSpaceCpuKernelMod::LaunchKernel<int8_t>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
&DepthToSpaceCpuKernelMod::LaunchKernel<int16_t>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
&DepthToSpaceCpuKernelMod::LaunchKernel<int>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
&DepthToSpaceCpuKernelMod::LaunchKernel<int64_t>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
&DepthToSpaceCpuKernelMod::LaunchKernel<uint8_t>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
||||
&DepthToSpaceCpuKernelMod::LaunchKernel<uint16_t>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
&DepthToSpaceCpuKernelMod::LaunchKernel<uint32_t>},
|
||||
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
&DepthToSpaceCpuKernelMod::LaunchKernel<uint64_t>}};
|
||||
|
||||
std::vector<KernelAttr> DepthToSpaceCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, DepthToSpaceFunc> &pair) { return pair.first; });
|
||||
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, DepthToSpace, DepthToSpaceCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,68 +18,39 @@
|
|||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DEPTHTOSPACE_CPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class DepthToSpaceCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
DepthToSpaceCpuKernelMod() = default;
|
||||
~DepthToSpaceCpuKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
using DepthToSpaceFunc = std::function<bool(DepthToSpaceCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &)>;
|
||||
static std::vector<std::pair<KernelAttr, DepthToSpaceFunc>> func_list_;
|
||||
DepthToSpaceFunc kernel_func_;
|
||||
|
||||
std::vector<size_t> input_shape_;
|
||||
std::vector<size_t> output_shape_;
|
||||
size_t block_size_{0};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
DepthToSpace, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
DepthToSpaceCpuKernelMod, float);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
DepthToSpace, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
DepthToSpaceCpuKernelMod, float16);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(DepthToSpace,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
DepthToSpaceCpuKernelMod, int8_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(DepthToSpace,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
DepthToSpaceCpuKernelMod, int16_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(DepthToSpace,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
DepthToSpaceCpuKernelMod, int);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(DepthToSpace,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
DepthToSpaceCpuKernelMod, int64_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(DepthToSpace,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
DepthToSpaceCpuKernelMod, uint8_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(DepthToSpace,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
||||
DepthToSpaceCpuKernelMod, uint16_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(DepthToSpace,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
DepthToSpaceCpuKernelMod, uint32_t);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(DepthToSpace,
|
||||
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
DepthToSpaceCpuKernelMod, uint64_t);
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DEPTHTOSPACE_CPU_KERNEL_H_
|
||||
|
|
|
@ -75,5 +75,7 @@ void DropoutCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|||
output_addr[i] = mask_addr[i] * input_addr[i] * scale;
|
||||
}
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Dropout, DropoutCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -46,8 +46,6 @@ class DropoutCpuKernelMod : public NativeCpuKernelMod {
|
|||
float keep_prob_{0.0};
|
||||
uint64_t tensor_size_{1};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(Dropout, KernelAttr(), DropoutCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DROPOUT_CPU_KERNEL_H_
|
||||
|
|
|
@ -101,5 +101,7 @@ void DropoutGradBwdCpuKernelMod::DropoutBackwardKernel(const std::vector<Address
|
|||
DropoutGrad(input, mask, output, SizeToInt(num_count_), scale);
|
||||
}
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, DropoutGrad, DropoutGradBwdCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
#include <unordered_map>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -44,8 +44,6 @@ class DropoutGradBwdCpuKernelMod : public NativeCpuKernelMod {
|
|||
size_t num_count_{1};
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(DropoutGrad, KernelAttr(), DropoutGradBwdCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -111,5 +111,17 @@ void DynamicAssignCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inpu
|
|||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', output should be a Parameter.";
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> DynamicAssignCpuKernelMod::GetOpSupport() {
|
||||
static std::vector<KernelAttr> support_list = {
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)};
|
||||
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, DynamicAssign, DynamicAssignCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
#include <unordered_map>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -36,6 +36,9 @@ class DynamicAssignCpuKernelMod : public NativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
|
@ -45,26 +48,6 @@ class DynamicAssignCpuKernelMod : public NativeCpuKernelMod {
|
|||
size_t input_x_dtype_size_{4};
|
||||
CNodeWeakPtr node_wpt_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
DynamicAssign,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
DynamicAssignCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
DynamicAssign,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
DynamicAssignCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
DynamicAssign,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
DynamicAssignCpuKernelMod);
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
DynamicAssign,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
DynamicAssignCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -22,9 +22,7 @@ namespace kernel {
|
|||
namespace {
|
||||
constexpr size_t kDynamicShapeOutputNum = 1;
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
void TensorShapeCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
void TensorShapeCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
cnode_ptr_ = kernel_node;
|
||||
|
@ -34,10 +32,9 @@ void TensorShapeCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool TensorShapeCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
bool TensorShapeCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kDynamicShapeOutputNum, kernel_name_);
|
||||
auto node_ = cnode_ptr_.lock();
|
||||
if (node_ == nullptr) {
|
||||
|
@ -60,5 +57,8 @@ bool TensorShapeCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &i
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, DynamicShape, TensorShapeCpuKernelMod);
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, TensorShape, TensorShapeCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,11 +20,10 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class TensorShapeCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
TensorShapeCpuKernelMod() = default;
|
||||
|
@ -35,28 +34,6 @@ class TensorShapeCpuKernelMod : public NativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, int8_t)
|
||||
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, int16_t)
|
||||
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, int32_t)
|
||||
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, int64_t)
|
||||
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, uint8_t)
|
||||
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, uint16_t)
|
||||
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, uint32_t)
|
||||
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, uint64_t)
|
||||
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, bool)
|
||||
|
||||
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, int8_t)
|
||||
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, int16_t)
|
||||
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, int32_t)
|
||||
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, int64_t)
|
||||
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, uint8_t)
|
||||
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, uint16_t)
|
||||
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, uint32_t)
|
||||
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, uint64_t)
|
||||
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, bool)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include "plugin/device/cpu/kernel/dynamic_stitch_cpu_kernel.h"
|
||||
#include <functional>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -24,19 +25,13 @@ namespace kernel {
|
|||
namespace {
|
||||
constexpr size_t kDynamicStitchOutputNum = 1;
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
void DynamicStitchCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
cnode_ptr_ = kernel_node;
|
||||
}
|
||||
|
||||
size_t GetShapeSize(const std::vector<size_t> &shape) {
|
||||
return std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies<size_t>());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DynamicStitchCpuKernelMod<T>::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
bool DynamicStitchCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kDynamicStitchOutputNum, kernel_name_);
|
||||
auto node_ = cnode_ptr_.lock();
|
||||
int first_dim_size = 0;
|
||||
|
@ -84,14 +79,86 @@ void DynamicStitchCpuKernelMod<T>::LaunchKernel(const std::vector<kernel::Addres
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool DynamicStitchCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
LaunchKernel(inputs, outputs);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, DynamicStitchCpuKernelMod::DynamicStitchFunc>> DynamicStitchCpuKernelMod::func_list_ =
|
||||
{{KernelAttr()
|
||||
.AddAllSameAttr(true)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&DynamicStitchCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr()
|
||||
.AddAllSameAttr(true)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt8)
|
||||
.AddOutputAttr(kNumberTypeInt8),
|
||||
&DynamicStitchCpuKernelMod::LaunchKernel<int8_t>},
|
||||
{KernelAttr()
|
||||
.AddAllSameAttr(true)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt16)
|
||||
.AddOutputAttr(kNumberTypeInt16),
|
||||
&DynamicStitchCpuKernelMod::LaunchKernel<int16_t>},
|
||||
{KernelAttr()
|
||||
.AddAllSameAttr(true)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
&DynamicStitchCpuKernelMod::LaunchKernel<int32_t>},
|
||||
{KernelAttr()
|
||||
.AddAllSameAttr(true)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
&DynamicStitchCpuKernelMod::LaunchKernel<int64_t>},
|
||||
{KernelAttr()
|
||||
.AddAllSameAttr(true)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeUInt8)
|
||||
.AddOutputAttr(kNumberTypeUInt8),
|
||||
&DynamicStitchCpuKernelMod::LaunchKernel<uint8_t>},
|
||||
{KernelAttr()
|
||||
.AddAllSameAttr(true)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeUInt16)
|
||||
.AddOutputAttr(kNumberTypeUInt16),
|
||||
&DynamicStitchCpuKernelMod::LaunchKernel<uint16_t>},
|
||||
{KernelAttr()
|
||||
.AddAllSameAttr(true)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeUInt32)
|
||||
.AddOutputAttr(kNumberTypeUInt32),
|
||||
&DynamicStitchCpuKernelMod::LaunchKernel<uint32_t>},
|
||||
{KernelAttr()
|
||||
.AddAllSameAttr(true)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeUInt64)
|
||||
.AddOutputAttr(kNumberTypeUInt64),
|
||||
&DynamicStitchCpuKernelMod::LaunchKernel<uint64_t>},
|
||||
{KernelAttr()
|
||||
.AddAllSameAttr(true)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeBool)
|
||||
.AddOutputAttr(kNumberTypeBool),
|
||||
&DynamicStitchCpuKernelMod::LaunchKernel<bool>}};
|
||||
|
||||
void DynamicStitchCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
cnode_ptr_ = kernel_node;
|
||||
|
||||
std::vector<KernelAttr> support_list;
|
||||
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, DynamicStitchFunc> &pair) { return pair.first; });
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, support_list);
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "DynamicStitch does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
|
||||
kernel_func_ = func_list_[index].second;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, DynamicStitch, DynamicStitchCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,12 +19,13 @@
|
|||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
|
||||
class DynamicStitchCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
DynamicStitchCpuKernelMod() = default;
|
||||
|
@ -32,26 +33,20 @@ class DynamicStitchCpuKernelMod : public NativeCpuKernelMod {
|
|||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, outputs);
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
using DynamicStitchFunc = std::function<bool(DynamicStitchCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &)>;
|
||||
static std::vector<std::pair<KernelAttr, DynamicStitchFunc>> func_list_;
|
||||
DynamicStitchFunc kernel_func_;
|
||||
size_t input_tuple_num_{1};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCpuKernelMod, int8_t)
|
||||
MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCpuKernelMod, int16_t)
|
||||
MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCpuKernelMod, int32_t)
|
||||
MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCpuKernelMod, int64_t)
|
||||
MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCpuKernelMod, uint8_t)
|
||||
MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCpuKernelMod, uint16_t)
|
||||
MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCpuKernelMod, uint32_t)
|
||||
MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCpuKernelMod, uint64_t)
|
||||
MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCpuKernelMod, bool)
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DYNAMIC_STITCH_CPU_KERNEL_H_
|
||||
|
|
|
@ -15,7 +15,9 @@
|
|||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/eigen/cholesky_cpu_kernel.h"
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/kernel/eigen/eigen_common_utils.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "Eigen/Cholesky"
|
||||
|
@ -30,8 +32,7 @@ constexpr size_t kRowIndex = 2;
|
|||
constexpr size_t kColIndex = 1;
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
void CholeskyCpuKernelMod<T>::InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
|
||||
void CholeskyCpuKernelMod::InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
|
||||
if (shape.empty()) {
|
||||
MS_LOG_EXCEPTION << kernel_name_ << " input or output shape is empty which is invalid.";
|
||||
}
|
||||
|
@ -52,8 +53,7 @@ void CholeskyCpuKernelMod<T>::InitMatrixInfo(const std::vector<size_t> &shape, s
|
|||
outer_batch_ /= ((*row) * (*col));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CholeskyCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
void CholeskyCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
|
@ -72,11 +72,18 @@ void CholeskyCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
|||
if (common::AnfAlgo::HasNodeAttr(LOWER, kernel_node)) {
|
||||
lower_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, LOWER);
|
||||
}
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "Cholesky does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool CholeskyCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
bool CholeskyCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
T *batch_input_value = reinterpret_cast<T *>(inputs[kInputIndex]->addr);
|
||||
T *batch_output_value = reinterpret_cast<T *>(outputs[kOutputIndex]->addr);
|
||||
Eigen::LLT<Matrix<T, RowMajor>> llt;
|
||||
|
@ -102,5 +109,20 @@ bool CholeskyCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs, cons
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, CholeskyCpuKernelMod::CholeskyFunc>> CholeskyCpuKernelMod::func_list_ = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&CholeskyCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
&CholeskyCpuKernelMod::LaunchKernel<double>}};
|
||||
|
||||
std::vector<KernelAttr> CholeskyCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, CholeskyFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Cholesky, CholeskyCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,12 +17,12 @@
|
|||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHOLESKY_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHOLESKY_CPU_KERNEL_H_
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class CholeskyCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
CholeskyCpuKernelMod() = default;
|
||||
|
@ -30,10 +30,25 @@ class CholeskyCpuKernelMod : public NativeCpuKernelMod {
|
|||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
void InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col);
|
||||
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
using CholeskyFunc =
|
||||
std::function<bool(CholeskyCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
static std::vector<std::pair<KernelAttr, CholeskyFunc>> func_list_;
|
||||
CholeskyFunc kernel_func_;
|
||||
|
||||
bool lower_{true};
|
||||
bool clean_{true};
|
||||
size_t outer_batch_{1};
|
||||
|
@ -43,12 +58,6 @@ class CholeskyCpuKernelMod : public NativeCpuKernelMod {
|
|||
size_t output_col_{1};
|
||||
TypeId dtype_{kNumberTypeFloat32};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(Cholesky, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
CholeskyCpuKernelMod, float)
|
||||
|
||||
MS_REG_CPU_KERNEL_T(Cholesky, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
CholeskyCpuKernelMod, double)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHOLESKY_CPU_KERNEL_H_
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/eigen/cholesky_solve_cpu_kernel.h"
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/kernel/eigen/eigen_common_utils.h"
|
||||
#include "Eigen/Cholesky"
|
||||
namespace mindspore {
|
||||
|
@ -29,8 +31,7 @@ constexpr size_t kRowIndex = 2;
|
|||
constexpr size_t kColIndex = 1;
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
void CholeskySolveCpuKernelMod<T>::InitRightMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
|
||||
void CholeskySolveCpuKernelMod::InitRightMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
|
||||
if (shape.empty()) {
|
||||
MS_LOG_EXCEPTION << kernel_name_ << " input shape is empty which is invalid.";
|
||||
}
|
||||
|
@ -47,9 +48,8 @@ void CholeskySolveCpuKernelMod<T>::InitRightMatrixInfo(const std::vector<size_t>
|
|||
outer_batch_ /= ((*row) * (*col));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CholeskySolveCpuKernelMod<T>::InitLeftMatrixInfo(const std::vector<size_t> &shape, const bool is_rank_equal,
|
||||
size_t *row, size_t *col) {
|
||||
void CholeskySolveCpuKernelMod::InitLeftMatrixInfo(const std::vector<size_t> &shape, const bool is_rank_equal,
|
||||
size_t *row, size_t *col) {
|
||||
if (shape.empty()) {
|
||||
MS_LOG_EXCEPTION << kernel_name_ << " input or output shape is empty which is invalid.";
|
||||
}
|
||||
|
@ -62,8 +62,7 @@ void CholeskySolveCpuKernelMod<T>::InitLeftMatrixInfo(const std::vector<size_t>
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CholeskySolveCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
void CholeskySolveCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
|
@ -85,11 +84,18 @@ void CholeskySolveCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
|||
MS_LOG_EXCEPTION << kernel_name_ << " llt solve input a row is not match to b row: " << input_a_row_ << " vs "
|
||||
<< input_b_row_;
|
||||
}
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "CholeskySolve does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool CholeskySolveCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
bool CholeskySolveCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
T *batch_input_value = reinterpret_cast<T *>(inputs[kInputAIndex]->addr);
|
||||
T *batch_input_b_value = reinterpret_cast<T *>(inputs[kInputBIndex]->addr);
|
||||
T *batch_output_value = reinterpret_cast<T *>(outputs[kOutputIndex]->addr);
|
||||
|
@ -110,5 +116,20 @@ bool CholeskySolveCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs,
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, CholeskySolveCpuKernelMod::CholeskySolveFunc>> CholeskySolveCpuKernelMod::func_list_ =
|
||||
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&CholeskySolveCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
&CholeskySolveCpuKernelMod::LaunchKernel<double>}};
|
||||
|
||||
std::vector<KernelAttr> CholeskySolveCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, CholeskySolveFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, CholeskySolve, CholeskySolveCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,12 +17,12 @@
|
|||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHOLESKY_CPU_SOLVE_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHOLESKY_CPU_SOLVE_KERNEL_H_
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class CholeskySolveCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
CholeskySolveCpuKernelMod() = default;
|
||||
|
@ -30,12 +30,26 @@ class CholeskySolveCpuKernelMod : public NativeCpuKernelMod {
|
|||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
void InitLeftMatrixInfo(const std::vector<size_t> &shape, const bool is_rank_equal, size_t *row, size_t *col);
|
||||
void InitRightMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col);
|
||||
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
using CholeskySolveFunc =
|
||||
std::function<bool(CholeskySolveCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
static std::vector<std::pair<KernelAttr, CholeskySolveFunc>> func_list_;
|
||||
CholeskySolveFunc kernel_func_;
|
||||
|
||||
size_t outer_batch_{1};
|
||||
size_t input_a_row_{1};
|
||||
size_t input_a_col_{1};
|
||||
|
@ -46,15 +60,6 @@ class CholeskySolveCpuKernelMod : public NativeCpuKernelMod {
|
|||
TypeId dtype_{kNumberTypeFloat32};
|
||||
bool lower_{false};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
CholeskySolve,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
CholeskySolveCpuKernelMod, float)
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
CholeskySolve,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
CholeskySolveCpuKernelMod, double)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHOLESKY_CPU_SOLVE_KERNEL_H_
|
||||
|
|
|
@ -15,7 +15,9 @@
|
|||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/eigen/eig_cpu_kernel.h"
|
||||
#include <algorithm>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/kernel/eigen/eigen_common_utils.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "Eigen/Eigenvalues"
|
||||
|
@ -28,8 +30,7 @@ constexpr size_t kOutputsNumNV = 1;
|
|||
constexpr size_t kOutputsNumV = 2;
|
||||
} // namespace
|
||||
|
||||
template <typename T, typename C>
|
||||
void EigCpuKernelMod<T, C>::InitMatrixInfo(const std::vector<size_t> &shape) {
|
||||
void EigCpuKernelMod::InitMatrixInfo(const std::vector<size_t> &shape) {
|
||||
if (shape.size() < kShape2dDims) {
|
||||
MS_LOG_EXCEPTION << "For '" << kernel_name_ << "', the rank of parameter 'a' must be at least 2, but got "
|
||||
<< shape.size() << " dimensions.";
|
||||
|
@ -48,8 +49,7 @@ void EigCpuKernelMod<T, C>::InitMatrixInfo(const std::vector<size_t> &shape) {
|
|||
batch_size_ /= (row_size_ * col_size_);
|
||||
}
|
||||
|
||||
template <typename T, typename C>
|
||||
void EigCpuKernelMod<T, C>::InitKernel(const CNodePtr &kernel_node) {
|
||||
void EigCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
// If compute_v_ is true, then: w, v = Eig(a)
|
||||
// If compute_v_ is false, then: w = Eig(a)
|
||||
|
@ -63,10 +63,17 @@ void EigCpuKernelMod<T, C>::InitKernel(const CNodePtr &kernel_node) {
|
|||
CHECK_KERNEL_OUTPUTS_NUM(output_num, expect_output_num, kernel_name_);
|
||||
auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
InitMatrixInfo(input_shape);
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "Eig does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
}
|
||||
|
||||
template <typename T, typename C>
|
||||
bool EigCpuKernelMod<T, C>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
bool EigCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
auto input_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto output_w_addr = reinterpret_cast<C *>(outputs[0]->addr);
|
||||
|
@ -92,5 +99,46 @@ bool EigCpuKernelMod<T, C>::Launch(const std::vector<AddressPtr> &inputs, const
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, EigCpuKernelMod::EigFunc>> EigCpuKernelMod::func_list_ = {
|
||||
// If compute_v is false.
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeComplex64),
|
||||
&EigCpuKernelMod::LaunchKernel<float, float_complex>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeComplex128),
|
||||
&EigCpuKernelMod::LaunchKernel<double, double_complex>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
|
||||
&EigCpuKernelMod::LaunchKernel<float_complex, float_complex>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
|
||||
&EigCpuKernelMod::LaunchKernel<double_complex, double_complex>},
|
||||
// If compute_v is true.
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeComplex64)
|
||||
.AddOutputAttr(kNumberTypeComplex64),
|
||||
&EigCpuKernelMod::LaunchKernel<float, float_complex>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128),
|
||||
&EigCpuKernelMod::LaunchKernel<double, double_complex>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex64)
|
||||
.AddOutputAttr(kNumberTypeComplex64)
|
||||
.AddOutputAttr(kNumberTypeComplex64),
|
||||
&EigCpuKernelMod::LaunchKernel<float_complex, float_complex>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128),
|
||||
&EigCpuKernelMod::LaunchKernel<double_complex, double_complex>}};
|
||||
|
||||
std::vector<KernelAttr> EigCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, EigFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Eig, EigCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,8 +19,9 @@
|
|||
|
||||
#include <vector>
|
||||
#include <complex>
|
||||
#include <utility>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -28,60 +29,40 @@ namespace kernel {
|
|||
using float_complex = std::complex<float>;
|
||||
using double_complex = std::complex<double>;
|
||||
|
||||
/**
|
||||
* this is for Generic matrix eigenvalues and eigenvectors
|
||||
* @tparam T , input Type
|
||||
* @tparam C , output Type, complex
|
||||
*/
|
||||
template <typename T, typename C>
|
||||
class EigCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
EigCpuKernelMod() = default;
|
||||
~EigCpuKernelMod() override = default;
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
void InitMatrixInfo(const std::vector<size_t> &shape);
|
||||
|
||||
/**
|
||||
* this is for Generic matrix eigenvalues and eigenvectors
|
||||
* @tparam T , input Type
|
||||
* @tparam C , output Type, complex
|
||||
*/
|
||||
template <typename T, typename S>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
using EigFunc = std::function<bool(EigCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
static std::vector<std::pair<KernelAttr, EigFunc>> func_list_;
|
||||
EigFunc kernel_func_;
|
||||
|
||||
bool compute_v_{true};
|
||||
size_t row_size_{1};
|
||||
size_t col_size_{1};
|
||||
size_t batch_size_{1};
|
||||
};
|
||||
|
||||
// If compute_v is false.
|
||||
MS_REG_CPU_KERNEL_T_S(Eig, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeComplex64),
|
||||
EigCpuKernelMod, float, float_complex);
|
||||
MS_REG_CPU_KERNEL_T_S(Eig, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeComplex128),
|
||||
EigCpuKernelMod, double, double_complex);
|
||||
MS_REG_CPU_KERNEL_T_S(Eig, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
|
||||
EigCpuKernelMod, float_complex, float_complex);
|
||||
MS_REG_CPU_KERNEL_T_S(Eig, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
|
||||
EigCpuKernelMod, double_complex, double_complex);
|
||||
// If compute_v is true.
|
||||
MS_REG_CPU_KERNEL_T_S(
|
||||
Eig,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
|
||||
EigCpuKernelMod, float, float_complex);
|
||||
MS_REG_CPU_KERNEL_T_S(Eig,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128),
|
||||
EigCpuKernelMod, double, double_complex);
|
||||
MS_REG_CPU_KERNEL_T_S(Eig,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex64)
|
||||
.AddOutputAttr(kNumberTypeComplex64)
|
||||
.AddOutputAttr(kNumberTypeComplex64),
|
||||
EigCpuKernelMod, float_complex, float_complex);
|
||||
MS_REG_CPU_KERNEL_T_S(Eig,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128),
|
||||
EigCpuKernelMod, double_complex, double_complex);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/eigen/eigh_cpu_kernel.h"
|
||||
#include <algorithm>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include "plugin/device/cpu/kernel/eigen/eigen_common_utils.h"
|
||||
#include "Eigen/Eigenvalues"
|
||||
|
@ -27,8 +29,7 @@ constexpr size_t kInputsNum = 1;
|
|||
constexpr size_t kOutputsNum = 2;
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
void EighCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
void EighCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
compute_eigen_vectors_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, C_EIEH_VECTOR);
|
||||
|
@ -44,6 +45,15 @@ void EighCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
|||
<< A_shape[kDim1] << "].";
|
||||
}
|
||||
m_ = A_shape[kDim0];
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "Eigh does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
kernel_func_ = std::get<1>(func_list_[index]);
|
||||
const size_t kTwoIdx = 2;
|
||||
init_io_func_ = std::get<kTwoIdx>(func_list_[index]);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -70,14 +80,14 @@ void SolveComplexMatrix(const Map<MatrixSquare<T>> &A, Map<MatrixSquare<T>> *out
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void EighCpuKernelMod<T>::InitInputOutputSize(const CNodePtr &kernel_node) {
|
||||
void EighCpuKernelMod::InitIOFunc(const CNodePtr &kernel_node) {
|
||||
NativeCpuKernelMod::InitInputOutputSize(kernel_node);
|
||||
(void)workspace_size_list_.emplace_back(m_ * m_ * sizeof(T));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool EighCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
bool EighCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
|
||||
auto A_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
// is the Matrix a symmetric matrix(true lower triangle, false upper triangle)
|
||||
|
@ -103,5 +113,39 @@ bool EighCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs, const st
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::tuple<KernelAttr, EighCpuKernelMod::EighFunc, EighCpuKernelMod::EighInitFunc>>
|
||||
EighCpuKernelMod::func_list_ = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&EighCpuKernelMod::LaunchKernel<float>, &EighCpuKernelMod::InitIOFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
&EighCpuKernelMod::LaunchKernel<double>, &EighCpuKernelMod::InitIOFunc<double>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
|
||||
&EighCpuKernelMod::LaunchKernel<float_complex>, &EighCpuKernelMod::InitIOFunc<float_complex>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
|
||||
&EighCpuKernelMod::LaunchKernel<double_complex>, &EighCpuKernelMod::InitIOFunc<double_complex>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&EighCpuKernelMod::LaunchKernel<float>, &EighCpuKernelMod::InitIOFunc<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
&EighCpuKernelMod::LaunchKernel<double>, &EighCpuKernelMod::InitIOFunc<double>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex64)
|
||||
.AddOutputAttr(kNumberTypeComplex64)
|
||||
.AddOutputAttr(kNumberTypeComplex64),
|
||||
&EighCpuKernelMod::LaunchKernel<float_complex>, &EighCpuKernelMod::InitIOFunc<float_complex>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128),
|
||||
&EighCpuKernelMod::LaunchKernel<double_complex>, &EighCpuKernelMod::InitIOFunc<double_complex>}};
|
||||
|
||||
std::vector<KernelAttr> EighCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::tuple<KernelAttr, EighFunc, EighInitFunc> &item) { return std::get<0>(item); });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Eigh, EighCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,67 +19,52 @@
|
|||
|
||||
#include <vector>
|
||||
#include <complex>
|
||||
#include <tuple>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
||||
using float_complex = std::complex<float>;
|
||||
using double_complex = std::complex<double>;
|
||||
/**
|
||||
* this is for Symmetric matrix eigenvalues & eigenvectors, can decompress the lower/upper triangle matrix
|
||||
* @tparam T , input Type
|
||||
* @tparam C , output Type, complex
|
||||
*/
|
||||
template <typename T>
|
||||
class EighCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
EighCpuKernelMod() = default;
|
||||
~EighCpuKernelMod() override = default;
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
void InitInputOutputSize(const CNodePtr &kernel_node) override;
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
void InitInputOutputSize(const CNodePtr &kernel_node) override { init_io_func_(this, kernel_node); }
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
/**
|
||||
* this is for Symmetric matrix eigenvalues & eigenvectors, can decompress the lower/upper triangle matrix
|
||||
* @tparam T , input Type
|
||||
*/
|
||||
template <typename T>
|
||||
void InitIOFunc(const CNodePtr &kernel_node);
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
using EighFunc =
|
||||
std::function<bool(EighCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
using EighInitFunc = std::function<void(EighCpuKernelMod *, const CNodePtr &)>;
|
||||
static std::vector<std::tuple<KernelAttr, EighFunc, EighInitFunc>> func_list_;
|
||||
EighFunc kernel_func_;
|
||||
EighInitFunc init_io_func_;
|
||||
|
||||
size_t m_{1};
|
||||
bool compute_eigen_vectors_{false};
|
||||
bool lower_{true};
|
||||
TypeId dtype_{kNumberTypeFloat32};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(Eigh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
EighCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(Eigh, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
EighCpuKernelMod, double);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(Eigh, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
|
||||
EighCpuKernelMod, float_complex);
|
||||
MS_REG_CPU_KERNEL_T(Eigh, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
|
||||
EighCpuKernelMod, double_complex);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Eigh,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
EighCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(
|
||||
Eigh,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
EighCpuKernelMod, double);
|
||||
|
||||
MS_REG_CPU_KERNEL_T(Eigh,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex64)
|
||||
.AddOutputAttr(kNumberTypeComplex64)
|
||||
.AddOutputAttr(kNumberTypeComplex64),
|
||||
EighCpuKernelMod, float_complex);
|
||||
MS_REG_CPU_KERNEL_T(Eigh,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128)
|
||||
.AddOutputAttr(kNumberTypeComplex128),
|
||||
EighCpuKernelMod, double_complex);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <unordered_map>
|
||||
#include <tuple>
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -35,8 +36,7 @@ constexpr size_t kColIndex = 1;
|
|||
constexpr int kZeroThreshold = INT32_MIN;
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
void LUCpuKernelMod<T>::InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
|
||||
void LUCpuKernelMod::InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
|
||||
constexpr size_t lu_min_dim = 1;
|
||||
if (shape.size() <= lu_min_dim) {
|
||||
MS_LOG_EXCEPTION << kernel_name_ << "shape is " << shape.size() << " which is invalid.";
|
||||
|
@ -50,8 +50,7 @@ void LUCpuKernelMod<T>::InitMatrixInfo(const std::vector<size_t> &shape, size_t
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void LUCpuKernelMod<T>::InitPivotVecInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
|
||||
void LUCpuKernelMod::InitPivotVecInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
|
||||
constexpr size_t pivot_min_dim = 1;
|
||||
if (shape.size() < pivot_min_dim) {
|
||||
MS_LOG_EXCEPTION << kernel_name_ << "pivots shape is " << shape.size() << " which is invalid.";
|
||||
|
@ -64,8 +63,7 @@ void LUCpuKernelMod<T>::InitPivotVecInfo(const std::vector<size_t> &shape, size_
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void LUCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
void LUCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
|
@ -81,10 +79,19 @@ void LUCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
|
|||
InitMatrixInfo(permutation_shape, &permutation_row_, &permutation_col_);
|
||||
auto pivots_shape = common::AnfAlgo::GetOutputInferShape(kernel_node, kPivotsIndex);
|
||||
InitPivotVecInfo(pivots_shape, &pivots_row_, &pivots_col_);
|
||||
|
||||
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(EXCEPTION) << "LU does not support this kernel data type: " << kernel_attr;
|
||||
}
|
||||
kernel_func_ = std::get<1>(func_list_[index]);
|
||||
const size_t kTwoIdx = 2;
|
||||
init_io_func_ = std::get<kTwoIdx>(func_list_[index]);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void LUCpuKernelMod<T>::InitInputOutputSize(const CNodePtr &kernel_node) {
|
||||
void LUCpuKernelMod::InitIOSize(const CNodePtr &kernel_node) {
|
||||
NativeCpuKernelMod::InitInputOutputSize(kernel_node);
|
||||
size_t lu_size = lu_col_ * sizeof(T);
|
||||
(void)workspace_size_list_.emplace_back(lu_size);
|
||||
|
@ -92,14 +99,14 @@ void LUCpuKernelMod<T>::InitInputOutputSize(const CNodePtr &kernel_node) {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
T LUCpuKernelMod<T>::GetPermutatedValue(const T *lu_value, const std::vector<int> &per_value, size_t i, size_t j) {
|
||||
T LUCpuKernelMod::GetPermutatedValue(const T *lu_value, const std::vector<int> &per_value, size_t i, size_t j) {
|
||||
const T *pered_lu_value = lu_value + per_value[i] * lu_col_ + j;
|
||||
return *pered_lu_value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool LUCpuKernelMod<T>::UpdateMajorPermutation(T *lu_value, std::vector<int> *per_value, int *pivots, size_t k,
|
||||
size_t rows) {
|
||||
bool LUCpuKernelMod::UpdateMajorPermutation(T *lu_value, std::vector<int> *per_value, int *pivots, size_t k,
|
||||
size_t rows) {
|
||||
T max_major_value = static_cast<T>(kZeroThreshold);
|
||||
size_t max_major_index = 0;
|
||||
for (size_t i = k; i < rows; ++i) {
|
||||
|
@ -118,16 +125,16 @@ bool LUCpuKernelMod<T>::UpdateMajorPermutation(T *lu_value, std::vector<int> *pe
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void LUCpuKernelMod<T>::SetPermutatedValue(T *lu_value, const std::vector<int> &per_value, size_t i, size_t j,
|
||||
const T &value) {
|
||||
void LUCpuKernelMod::SetPermutatedValue(T *lu_value, const std::vector<int> &per_value, size_t i, size_t j,
|
||||
const T &value) {
|
||||
T *per_lu_value = lu_value + per_value[i] * lu_col_ + j;
|
||||
*per_lu_value = value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool LUCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
bool LUCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
// input matrix of (m,n) PA = LU
|
||||
T *batch_a_value = reinterpret_cast<T *>(inputs[kLUaIndex]->addr);
|
||||
T *batch_lu_value = reinterpret_cast<T *>(outputs[kLuIndex]->addr);
|
||||
|
@ -233,5 +240,28 @@ bool LUCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::tuple<KernelAttr, LUCpuKernelMod::LUFunc, LUCpuKernelMod::InitFunc>> LUCpuKernelMod::func_list_ = {
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
&LUCpuKernelMod::LaunchKernel<float>, &LUCpuKernelMod::InitIOSize<float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
&LUCpuKernelMod::LaunchKernel<double>, &LUCpuKernelMod::InitIOSize<double>}};
|
||||
|
||||
std::vector<KernelAttr> LUCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::tuple<KernelAttr, LUFunc, InitFunc> &tuple_item) { return std::get<0>(tuple_item); });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, LU, LUCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,28 +17,49 @@
|
|||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_CPU_KERNEL_H_
|
||||
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class LUCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
LUCpuKernelMod() = default;
|
||||
~LUCpuKernelMod() override = default;
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
T GetPermutatedValue(const T *lu_value, const std::vector<int> &per_value, size_t i, size_t j);
|
||||
template <typename T>
|
||||
bool UpdateMajorPermutation(T *lu_value, std::vector<int> *per_value, int *pivots, size_t k, size_t rows);
|
||||
template <typename T>
|
||||
void SetPermutatedValue(T *lu_value, const std::vector<int> &per_value, size_t i, size_t j, const T &value);
|
||||
template <typename T>
|
||||
void InitIOSize(const CNodePtr &kernel_node);
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
using LUFunc = std::function<bool(LUCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
using InitFunc = std::function<void(LUCpuKernelMod *, const CNodePtr &)>;
|
||||
static std::vector<std::tuple<KernelAttr, LUFunc, InitFunc>> func_list_;
|
||||
LUFunc kernel_func_;
|
||||
InitFunc init_io_func_;
|
||||
|
||||
void InitInputOutputSize(const CNodePtr &kernel_node) override { init_io_func_(this, kernel_node); }
|
||||
|
||||
void InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col);
|
||||
void InitPivotVecInfo(const std::vector<size_t> &shape, size_t *row, size_t *col);
|
||||
void InitInputOutputSize(const CNodePtr &kernel_node) override;
|
||||
T GetPermutatedValue(const T *lu_value, const std::vector<int> &per_value, size_t i, size_t j);
|
||||
bool UpdateMajorPermutation(T *lu_value, std::vector<int> *per_value, int *pivots, size_t k, size_t rows);
|
||||
void SetPermutatedValue(T *lu_value, const std::vector<int> &per_value, size_t i, size_t j, const T &value);
|
||||
size_t batch_size_{1};
|
||||
size_t a_row_{1};
|
||||
size_t a_col_{1};
|
||||
|
@ -51,21 +72,6 @@ class LUCpuKernelMod : public NativeCpuKernelMod {
|
|||
TypeId dtype_{kNumberTypeFloat32};
|
||||
int *batch_pivots_{nullptr};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T(LU,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
LUCpuKernelMod, float);
|
||||
MS_REG_CPU_KERNEL_T(LU,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeInt32)
|
||||
.AddOutputAttr(kNumberTypeInt32),
|
||||
LUCpuKernelMod, double);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue