refactor cpu factory

TronZhang 2022-03-04 09:39:48 +08:00 committed by tronzhang
parent 8d2f098802
commit db2931516e
357 changed files with 9987 additions and 7765 deletions

View File

@@ -24,9 +24,10 @@
#include "include/common/utils/context/graph_kernel_flags.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "plugin/factory/ms_factory.h"
#include "runtime/device/kernel_runtime.h"
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/akg/akg_cpu_kernel_build.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/device/cpu/hal/device/kernel_select_cpu.h"
#include "backend/common/optimizer/optimizer.h"
#include "backend/common/optimizer/pass_manager.h"
@@ -375,11 +376,12 @@ void CPUSession::BuildKernel(const KernelGraph *kernel_graph) {
continue;
}
std::shared_ptr<kernel::NativeCpuKernelMod> cpu_kernel_mod =
kernel::NativeCpuKernelModFactory::GetInstance().Create(kernel_name, kernel_node);
kernel::Factory<kernel::NativeCpuKernelMod>::Instance().Create(kernel_name);
if (cpu_kernel_mod == nullptr) {
KernelNotSupportException(kernel_node);
}
try {
cpu_kernel_mod->SetCpuRefMapToKernelInfo(kernel_node);
cpu_kernel_mod->Init(kernel_node);
} catch (std::exception &e) {
MS_LOG(EXCEPTION) << e.what() << trace::DumpSourceLines(kernel_node);
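The call site above swaps the CPU-specific kernel::NativeCpuKernelModFactory::GetInstance().Create(kernel_name, kernel_node) for the generic kernel::Factory<kernel::NativeCpuKernelMod>::Instance().Create(kernel_name); the same change recurs in cpu_device_context.cc below. The new plugin/factory/ms_factory.h is not part of this excerpt, so the following is a minimal sketch, assuming a name-keyed singleton template with Register, Create, and IsRegistered matching the calls seen in this commit:

// Minimal sketch of the Factory<C> interface assumed by the call sites above;
// the real definition in plugin/factory/ms_factory.h is not shown in this diff.
#include <functional>
#include <map>
#include <memory>
#include <string>

template <class C>
class Factory {
 public:
  using CreatorFunc = std::function<std::shared_ptr<C>()>;

  // One factory instance per kernel-mod base class (Meyers singleton).
  static Factory &Instance() {
    static Factory instance;
    return instance;
  }

  // Invoked by the registration macros during static initialization.
  void Register(const std::string &name, CreatorFunc &&creator) {
    (void)kernel_mod_creators_.emplace(name, std::move(creator));
  }

  // Returns nullptr when nothing was registered under `name`, which is
  // exactly what the nullptr checks at the call sites rely on.
  std::shared_ptr<C> Create(const std::string &name) const {
    auto iter = kernel_mod_creators_.find(name);
    return iter == kernel_mod_creators_.end() ? nullptr : iter->second();
  }

  bool IsRegistered(const std::string &name) const {
    return kernel_mod_creators_.find(name) != kernel_mod_creators_.end();
  }

 private:
  Factory() = default;
  std::map<std::string, CreatorFunc> kernel_mod_creators_;
};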

View File

@@ -30,21 +30,20 @@ namespace server {
namespace kernel {
using mindspore::kernel::SGDCpuKernelMod;
template <typename T>
class SGDKernelMod : public SGDCpuKernelMod<T>, public OptimizerKernelMod {
class SGDKernelMod : public SGDCpuKernelMod, public OptimizerKernelMod {
public:
SGDKernelMod() = default;
~SGDKernelMod() override = default;
void InitKernel(const CNodePtr &cnode) override {
SGDCpuKernelMod<T>::InitKernel(cnode);
SGDCpuKernelMod::InitKernel(cnode);
InitServerKernelInputOutputSize(cnode);
GenerateReuseKernelNodeInfo();
}
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
return SGDCpuKernelMod<T>::Launch(inputs, workspace, outputs);
return SGDCpuKernelMod::Launch(inputs, workspace, outputs);
}
void GenerateReuseKernelNodeInfo() override {

View File

@@ -25,6 +25,7 @@
#include <vector>
#include <set>
#include <map>
#include <algorithm>
#include "utils/log_adapter.h"
#include "ir/dtype/type.h"

View File

@@ -543,16 +543,6 @@ bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &
return true;
}
int Sign(float x) {
if (x > 0) {
return 1;
}
if (x < 0) {
return -1;
}
return 0;
}
std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
const std::vector<AnfNodePtr> &input_list,
const std::vector<AnfNodePtr> &output_list) {
@@ -1011,5 +1001,122 @@ size_t UnitSizeInBytes(const mindspore::TypeId &t) {
return bytes;
}
KernelAttr &KernelAttr::AddInputAttr(const TypeId &ms_type, const std::string &format) {
input_type_.emplace_back(ms_type, format);
return *this;
}
KernelAttr &KernelAttr::AddOutputAttr(const TypeId &ms_type, const std::string &format) {
output_type_.emplace_back(ms_type, format);
return *this;
}
KernelAttr &KernelAttr::AddAllSameAttr(const bool &all_same) {
all_same_ = all_same;
return *this;
}
KernelAttr &KernelAttr::AddOutInRef(size_t output_index, size_t input_index) {
out_in_ref_map_[output_index] = input_index;
return *this;
}
void KernelAttr::SetInputAttrList(const std::vector<DataType> &addr_list) {
input_type_.assign(addr_list.begin(), addr_list.end());
}
std::ostream &operator<<(std::ostream &os, KernelAttr kernel_attr) {
std::stringstream ss;
ss << "[Kernel Attr] all same: " << kernel_attr.GetAllSame();
size_t input_num = kernel_attr.GetInputSize();
if (input_num > 0) {
ss << ", input(";
for (size_t i = 0; i < input_num; ++i) {
ss << TypeIdLabel(kernel_attr.GetInputAttr(i).first);
if (i != input_num - 1) {
ss << ",";
}
}
ss << ") ";
}
size_t output_num = kernel_attr.GetOutputSize();
if (output_num > 0) {
ss << ", output(";
for (size_t i = 0; i < output_num; ++i) {
ss << TypeIdLabel(kernel_attr.GetOutputAttr(i).first);
if (i != output_num - 1) {
ss << ",";
}
}
ss << ").";
}
return os << ss.str();
}
std::pair<bool, size_t> MatchKernelAttr(const KernelAttr &kernel_attr,
const std::vector<KernelAttr> &kernel_attr_list) {
// The kernel_attr to match should not be all-same; if it is, return no match.
if (kernel_attr.GetAllSame()) {
return std::make_pair(false, 0);
}
auto input_num = kernel_attr.GetInputSize();
auto output_num = kernel_attr.GetOutputSize();
for (size_t index = 0; index < kernel_attr_list.size(); ++index) {
const auto &cur_kernel_attr = kernel_attr_list[index];
auto cur_input_num = cur_kernel_attr.GetInputSize();
auto cur_output_num = cur_kernel_attr.GetOutputSize();
if (!cur_kernel_attr.GetAllSame() && (input_num != cur_input_num || output_num != cur_output_num)) {
continue;
}
bool mis_match = false;
for (size_t i = 0; i < input_num; ++i) {
auto dtype = cur_kernel_attr.GetInputAttr(cur_kernel_attr.GetAllSame() ? 0 : i).first;
auto format = cur_kernel_attr.GetInputAttr(cur_kernel_attr.GetAllSame() ? 0 : i).second;
if (kernel_attr.GetInputAttr(i).first != dtype || kernel_attr.GetInputAttr(i).second != format) {
mis_match = true;
break;
}
}
if (mis_match) {
continue;
}
for (size_t i = 0; i < output_num; ++i) {
auto dtype = cur_kernel_attr.GetOutputAttr(cur_kernel_attr.GetAllSame() ? 0 : i).first;
auto format = cur_kernel_attr.GetOutputAttr(cur_kernel_attr.GetAllSame() ? 0 : i).second;
if (kernel_attr.GetOutputAttr(i).first != dtype || kernel_attr.GetOutputAttr(i).second != format) {
mis_match = true;
break;
}
}
if (!mis_match) {
return std::make_pair(true, index);
}
}
return std::make_pair(false, 0);
}
KernelAttr GetKernelAttrFromBuildInfo(const KernelBuildInfoPtr &build_info) {
MS_EXCEPTION_IF_NULL(build_info);
KernelAttr kernel_attr;
for (size_t i = 0; i < build_info->GetInputNum(); i++) {
kernel_attr.AddInputAttr(build_info->GetInputDeviceType(i), build_info->GetInputFormat(i));
}
for (size_t j = 0; j < build_info->GetOutputNum(); j++) {
kernel_attr.AddOutputAttr(build_info->GetOutputDeviceType(j), build_info->GetOutputFormat(j));
}
return kernel_attr;
}
KernelAttr GetKernelAttrFromNode(const AnfNodePtr &kernel_node) {
auto build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
return GetKernelAttrFromBuildInfo(build_info);
}
} // namespace kernel
} // namespace mindspore
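KernelAttr moves here from the CPU device code with the same fluent builder interface, so a support entry reads as one chained expression and MatchKernelAttr can compare it against a node's actual signature. A short illustrative usage (kernel_node is a placeholder; the kNumberType* constants are the dtype ids used throughout this diff):

// Illustrative: build a support entry and match a node's signature against it.
std::vector<KernelAttr> support_list = {KernelAttr()
                                          .AddInputAttr(kNumberTypeFloat32)
                                          .AddInputAttr(kNumberTypeFloat32)
                                          .AddOutputAttr(kNumberTypeFloat32)
                                          .AddOutInRef(0, 0)};  // output 0 aliases input 0
KernelAttr node_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(node_attr, support_list);
if (!is_match) {
  MS_LOG(DEBUG) << "No matching signature for " << node_attr;  // uses operator<< above
}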

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_COMMON_UTILS_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_COMMON_UTILS_H_
#ifndef MINDSPORE_CCSRC_KERNEL_COMMON_UTILS_H_
#define MINDSPORE_CCSRC_KERNEL_COMMON_UTILS_H_
#include <dirent.h>
#include <memory>
@@ -23,6 +23,7 @@
#include <unordered_set>
#include <map>
#include <string>
#include <sstream>
#include <algorithm>
#include <vector>
#include <utility>
@@ -145,7 +146,6 @@ void SaveJsonInfo(const std::string &json_name, const std::string &info, const s
std::string GetProcessor(const AnfNodePtr &anf_node);
Processor GetProcessor(const string &processor);
bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b);
int Sign(float x);
std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
const std::vector<AnfNodePtr> &input_list,
const std::vector<AnfNodePtr> &output_list);
@@ -235,6 +235,41 @@ size_t GetCopySize(const std::vector<int64_t> &dim_offset, const std::vector<int
const std::vector<int64_t> &stop);
size_t UnitSizeInBytes(const mindspore::TypeId &t);
class KernelAttr {
public:
using DataType = std::pair<TypeId, std::string>;
KernelAttr() = default;
~KernelAttr() = default;
KernelAttr &AddInputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT);
KernelAttr &AddOutputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT);
KernelAttr &AddAllSameAttr(const bool &all_same);
KernelAttr &AddOutInRef(size_t output_index, size_t input_index);
const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; }
const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; }
const bool &GetAllSame() const { return all_same_; }
size_t GetInputSize() const { return input_type_.size(); }
size_t GetOutputSize() const { return output_type_.size(); }
const OutputInputRefMap &GetOutInRefMap() const { return out_in_ref_map_; }
void SetInputAttrList(const std::vector<DataType> &addr_list);
private:
std::vector<DataType> input_type_;
std::vector<DataType> output_type_;
bool all_same_{false};
// The map between kernel's output and input ref relationship.
OutputInputRefMap out_in_ref_map_;
};
std::ostream &operator<<(std::ostream &os, KernelAttr kernel_attr);
std::pair<bool, size_t> MatchKernelAttr(const KernelAttr &kernel_attr, const std::vector<KernelAttr> &attr_list);
KernelAttr GetKernelAttrFromBuildInfo(const KernelBuildInfoPtr &build_info);
KernelAttr GetKernelAttrFromNode(const AnfNodePtr &kernel_node);
#define CHECK_KERNEL_INPUTS_NUM(actual_inputs_num, expect_inputs_num, kernel_name) \
do { \
if ((actual_inputs_num) != (expect_inputs_num)) { \
@@ -261,4 +296,4 @@ size_t UnitSizeInBytes(const mindspore::TypeId &t);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_COMMON_UTILS_H_
#endif // MINDSPORE_CCSRC_KERNEL_COMMON_UTILS_H_
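Only the opening lines of CHECK_KERNEL_INPUTS_NUM are visible in the hunk above; it guards that a kernel received the expected argument count. A typical call site, for illustration (the constant name is hypothetical):

// Illustrative call site inside a kernel's Launch():
constexpr size_t kFooInputsNum = 2;
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kFooInputsNum, kernel_name_);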

View File

@@ -19,7 +19,8 @@
#include <memory>
#include <algorithm>
#include "kernel/common_utils.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "kernel/kernel_build_info.h"
#include "kernel/oplib/opinfo.h"
#include "kernel/oplib/oplib.h"
@@ -83,7 +84,7 @@ void GetInputFormat(const CNodePtr &kernel_node, std::vector<std::string> *input
}
}
void GetOutputFormatsAndDtypes(const CNodePtr &kernel_node, const KernelAttr &kernel_attr,
void GetOutputFormatsAndDtypes(const CNodePtr &kernel_node, const kernel::KernelAttr &kernel_attr,
std::vector<std::string> *output_formats, std::vector<TypeId> *output_types) {
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
@@ -107,7 +108,7 @@ bool InputDtypeMatch(TypeId InputAttr, TypeId input_type, bool strict) {
return false;
}
int GetOutputDtypeMatchedNum(const KernelAttr &kernel_attr, const std::vector<TypeId> &output_types) {
int GetOutputDtypeMatchedNum(const kernel::KernelAttr &kernel_attr, const std::vector<TypeId> &output_types) {
if (kernel_attr.GetOutputSize() != output_types.size()) {
MS_LOG(DEBUG) << "required output num:" << kernel_attr.GetInputSize()
<< ", actual output num:" << output_types.size();
@@ -126,7 +127,7 @@ int GetOutputDtypeMatchedNum(const KernelAttr &kernel_attr, const std::vector<Ty
return data_type_matched_num;
}
int GetInputDtypeFormatMatchedNum(const KernelAttr &kernel_attr, const std::vector<TypeId> &input_types,
int GetInputDtypeFormatMatchedNum(const kernel::KernelAttr &kernel_attr, const std::vector<TypeId> &input_types,
const std::vector<size_t> &input_not_cnode_indexes, bool strict) {
if (kernel_attr.GetInputSize() != input_types.size()) {
MS_LOG(DEBUG) << "required input num:" << kernel_attr.GetInputSize() << ", actual input num:" << input_types.size();
@@ -145,7 +146,7 @@ int GetInputDtypeFormatMatchedNum(const KernelAttr &kernel_attr, const std::vect
return data_type_matched_num;
}
void ExpandKernelAttr(const CNodePtr &kernel_node, KernelAttr *kernel_attr) {
void ExpandKernelAttr(const CNodePtr &kernel_node, kernel::KernelAttr *kernel_attr) {
MS_EXCEPTION_IF_NULL(kernel_attr);
size_t attr_num = kernel_attr->GetInputSize();
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
@@ -169,7 +170,7 @@ void ExpandKernelAttr(const CNodePtr &kernel_node, KernelAttr *kernel_attr) {
TypeId output_dtype = kernel_attr->GetOutputAttr(0).first;
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
for (size_t i = 1; i < output_num; ++i) {
kernel_attr->AddOutputAttr(output_dtype);
(void)kernel_attr->AddOutputAttr(output_dtype);
}
}
@@ -218,7 +219,7 @@ void KernelNotSupportException(const AnfNodePtr &kernel_node, const std::vector<
MS_EXCEPTION(TypeError) << operator_info.str() << trace::DumpSourceLines(kernel_node);
}
void UpdateDynamicKernelBuildInfoAndAttrs(const CNodePtr &kernel_node) {
void UpdateDynamicKernelBuildInfo(const CNodePtr &kernel_node) {
const std::string &op_name = common::AnfAlgo::GetCNodeName(kernel_node);
MS_LOG(INFO) << "Operator name: " << op_name;
// Set kernel build info
@@ -232,35 +233,9 @@ void UpdateDynamicKernelBuildInfoAndAttrs(const CNodePtr &kernel_node) {
std::vector<std::string> output_formats;
GetOutputFormat(kernel_node, &output_formats);
SetKernelBuildInfo(input_formats, input_types, output_formats, output_types, kernel_node.get());
// Set kernel attrs
KernelAttr attr;
for (size_t i = 0; i < input_types.size(); i++) {
attr.AddInputAttr(input_types[i]);
}
for (size_t j = 0; j < output_types.size(); j++) {
attr.AddInputAttr(output_types[j]);
}
std::vector<KernelAttr> kernel_attrs = kernel::NativeCpuKernelModFactory::GetInstance().GetSupportedKernelAttrList(
common::AnfAlgo::GetCNodeName(kernel_node));
kernel_attrs.emplace_back(attr);
kernel::NativeCpuKernelModFactory::GetInstance().UpdateKernelAttrs(op_name, kernel_attrs);
return;
}
void AddKernelAttr(const CNodePtr &kernel_node, const KernelAttr &kernel_attr) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<KernelAttr> kernel_attrs = kernel::NativeCpuKernelModFactory::GetInstance().GetSupportedKernelAttrList(
common::AnfAlgo::GetCNodeName(kernel_node));
if (kernel_attrs.size() == 1 && kernel_attrs[0].GetInputSize() == 0 && kernel_attrs[0].GetOutputSize() == 0) {
kernel_attrs.clear();
}
kernel_attrs.emplace_back(kernel_attr);
kernel::NativeCpuKernelModFactory::GetInstance().UpdateKernelAttrs(common::AnfAlgo::GetCNodeName(kernel_node),
kernel_attrs);
}
void UpdateCustomKernelBuildInfoAndAttrs(const CNodePtr &kernel_node, bool is_akg_op) {
void UpdateCustomKernelBuildInfo(const CNodePtr &kernel_node, bool is_akg_op) {
MS_EXCEPTION_IF_NULL(kernel_node);
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
if (is_akg_op) {
@@ -285,29 +260,11 @@ void UpdateCustomKernelBuildInfoAndAttrs(const CNodePtr &kernel_node, bool is_ak
builder->SetOutputsDeviceType(output_types);
builder->SetOutputsFormat(output_formats);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get());
if (!is_akg_op) {
// Update kernel attrs
KernelAttr attr;
if (input_types.size() != input_formats.size()) {
MS_LOG(EXCEPTION) << "input types size " << input_types.size() << " is not equal to input formats size "
<< input_formats.size() << " in op: " << common::AnfAlgo::GetCNodeName(kernel_node);
}
for (size_t i = 0; i < input_types.size(); ++i) {
attr.AddInputAttr(input_types[i], input_formats[i]);
}
if (output_types.size() != output_formats.size()) {
MS_LOG(EXCEPTION) << "output types size " << output_types.size() << " is not equal to output formats size "
<< output_formats.size() << " in op: " << common::AnfAlgo::GetCNodeName(kernel_node);
}
for (size_t i = 0; i < output_types.size(); ++i) {
attr.AddOutputAttr(output_types[i], output_formats[i]);
}
AddKernelAttr(kernel_node, attr);
}
}
KernelAttr FillNoneInKernelAttr(const CNodePtr &kernel_node, const std::vector<TypeId> &input_types,
const std::vector<TypeId> &output_types, const KernelAttr &kernel_attr) {
kernel::KernelAttr FillNoneInKernelAttr(const CNodePtr &kernel_node, const std::vector<TypeId> &input_types,
const std::vector<TypeId> &output_types,
const kernel::KernelAttr &kernel_attr) {
MS_EXCEPTION_IF_NULL(kernel_node);
// Only process Custom op
if (!IsPrimitiveCNode(kernel_node, prim::kPrimCustom)) {
@@ -320,17 +277,14 @@ KernelAttr FillNoneInKernelAttr(const CNodePtr &kernel_node, const std::vector<T
MS_LOG(DEBUG) << "required input num:" << kernel_attr.GetOutputSize() << ", actual input num:" << output_num;
return kernel_attr;
}
KernelAttr result;
bool has_none = false;
kernel::KernelAttr result;
// Fill inputs info.
for (size_t i = 0; i < input_num; ++i) {
auto type_format = kernel_attr.GetInputAttr(i);
if (type_format.first == TypeId::kMetaTypeNone) {
has_none = true;
type_format.first = input_types[i];
}
if (type_format.second.empty()) {
has_none = true;
type_format.second = kOpFormat_DEFAULT;
}
result.AddInputAttr(type_format.first, type_format.second);
@@ -339,18 +293,13 @@ KernelAttr FillNoneInKernelAttr(const CNodePtr &kernel_node, const std::vector<T
for (size_t i = 0; i < output_num; ++i) {
auto type_format = kernel_attr.GetOutputAttr(i);
if (type_format.first == TypeId::kMetaTypeNone) {
has_none = true;
type_format.first = output_types[i];
}
if (type_format.second.empty()) {
has_none = true;
type_format.second = kOpFormat_DEFAULT;
}
result.AddOutputAttr(type_format.first, type_format.second);
}
if (has_none) {
AddKernelAttr(kernel_node, result);
}
return result;
}
} // namespace
@@ -374,8 +323,8 @@ bool IsDynamicParamKernel(const std::string &op_name) {
return true;
}
bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr,
const std::vector<KernelAttr> &kernel_attrs, const std::vector<TypeId> &input_types,
bool SelectKernel(const CNodePtr &kernel_node, kernel::KernelAttr *selected_kernel_attr,
const std::vector<kernel::KernelAttr> &kernel_attrs, const std::vector<TypeId> &input_types,
const std::vector<size_t> &input_not_cnode_indexes, const std::vector<TypeId> &output_types,
std::pair<bool, bool> *matched, bool strict) {
MS_EXCEPTION_IF_NULL(kernel_node);
@@ -391,6 +340,7 @@ bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr,
MS_LOG(DEBUG) << "Output num is not equal!";
continue;
}
auto new_kernel_attr = FillNoneInKernelAttr(kernel_node, input_types, output_types, kernel_attr);
int input_dtype_matched_num =
GetInputDtypeFormatMatchedNum(new_kernel_attr, input_types, input_not_cnode_indexes, strict);
@@ -412,17 +362,17 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
const std::string &op_name = common::AnfAlgo::GetCNodeName(kernel_node);
if (IsPrimitiveCNode(kernel_node, prim::kPrimCustom)) {
if (!kernel::NativeCpuKernelModFactory::GetInstance().SearchRegisteredOp(op_name)) {
if (!kernel::Factory<kernel::NativeCpuKernelMod>::Instance().IsRegistered(op_name)) {
auto tp = common::AnfAlgo::GetNodeAttr<std::string>(kernel_node, kAttrFuncType);
if (tp == kCustomTypePyfunc) {
kernel::NativeCpuKernelRegistrar(op_name, KernelAttr(),
[]() { return std::make_shared<kernel::PyFuncCpuKernelMod>(); });
kernel::Factory<kernel::NativeCpuKernelMod>::Instance().Register(
op_name, []() { return std::make_shared<kernel::PyFuncCpuKernelMod>(); });
} else if (tp == kCustomTypeAOT) {
kernel::NativeCpuKernelRegistrar(op_name, KernelAttr(),
[]() { return std::make_shared<kernel::CustomAOTCpuKernelMod>(); });
kernel::Factory<kernel::NativeCpuKernelMod>::Instance().Register(
op_name, []() { return std::make_shared<kernel::CustomAOTCpuKernelMod>(); });
} else if (tp == kCustomTypeJULIA) {
kernel::NativeCpuKernelRegistrar(op_name, KernelAttr(),
[]() { return std::make_shared<kernel::CustomJULIACpuKernelMod>(); });
kernel::Factory<kernel::NativeCpuKernelMod>::Instance().Register(
op_name, []() { return std::make_shared<kernel::CustomJULIACpuKernelMod>(); });
} else {
MS_LOG(EXCEPTION)
<< "Unsupported func type for Custom operator on CPU, it should be 'pyfunc' or 'aot' or 'julia', "
@@ -434,13 +384,15 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
MS_LOG(WARNING) << "Not find operator information for Custom operator[" << op_name << "]. "
<< "Infer operator information from inputs. For more details, "
<< "please refer to 'mindspore.ops.Custom' at https://www.mindspore.cn.";
return UpdateCustomKernelBuildInfoAndAttrs(kernel_node, false);
UpdateCustomKernelBuildInfo(kernel_node, false);
return;
}
} else if (IsDynamicParamKernel(op_name)) {
// Select for dynamic kernel(both the number and data type are undetermined).
return UpdateDynamicKernelBuildInfoAndAttrs(kernel_node);
UpdateDynamicKernelBuildInfo(kernel_node);
return;
} else if (IsAKGSparseOP(kernel_node)) {
return UpdateCustomKernelBuildInfoAndAttrs(kernel_node, true);
return UpdateCustomKernelBuildInfo(kernel_node, true);
}
std::vector<std::string> input_formats;
@@ -450,23 +402,11 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
std::vector<TypeId> output_types;
std::vector<TypeId> selected_output_types;
MS_LOG(INFO) << "SetKernelInfo, CNode Name: " << op_name;
auto kernel_attrs = kernel::NativeCpuKernelModFactory::GetInstance().GetSupportedKernelAttrList(
common::AnfAlgo::GetCNodeName(kernel_node));
if (kernel_attrs.empty() || (kernel_attrs[0].GetInputSize() == 0 && kernel_attrs[0].GetOutputSize() == 0)) {
MS_LOG(DEBUG) << "Operator[" << op_name << "] will get ops attr info.";
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kCPU);
if (op_info_ptr == nullptr) {
MS_LOG(EXCEPTION) << "Not find op[" << op_name << "] in cpu. For more details, "
<< "please refer to the list of supported cpu operations at https://www.mindspore.cn.";
}
kernel_attrs.clear();
kernel::NativeCpuKernelModFactory::GetInstance().SetKernelAttrs(op_info_ptr, &kernel_attrs);
kernel::NativeCpuKernelModFactory::GetInstance().UpdateKernelAttrs(op_name, kernel_attrs);
}
GetInputDtypes(kernel_node, &input_types, &input_not_cnode_indexes);
GetOutputDtypes(kernel_node, &output_types);
KernelAttr selected_kernel_attr;
kernel::KernelAttr selected_kernel_attr;
std::pair<bool, bool> matched = std::make_pair(false, false);
auto kernel_attrs = kernel::NativeCpuKernelMod::GetCpuSupportedList(op_name);
if (!SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_types, input_not_cnode_indexes,
output_types, &matched, true)) {
if (op_name == "Cast") {

View File

@@ -31,52 +31,6 @@ namespace cpu {
using DataType = std::pair<TypeId, std::string>;
void SetKernelInfo(const CNodePtr &apply_kernel_ptr);
// Indicate whether the kernel input/output number are variable.
bool IsDynamicParamKernel(const std::string &op_name);
class KernelAttr {
public:
KernelAttr() : all_same_(0) {}
~KernelAttr() = default;
KernelAttr &AddInputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) {
input_type_.emplace_back(ms_type, format);
return *this;
}
KernelAttr &AddOutputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) {
output_type_.emplace_back(ms_type, format);
return *this;
}
KernelAttr &SetAllSameAttr(bool all_same) {
all_same_ = all_same;
return *this;
}
KernelAttr &AddOutInRef(size_t output_index, size_t input_index) {
out_in_ref_map_[output_index] = input_index;
return *this;
}
const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; }
const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; }
bool GetAllSame() const { return all_same_; }
void SetInputAttrList(const std::vector<DataType> &addr_list) {
input_type_.assign(addr_list.begin(), addr_list.end());
}
size_t GetInputSize() const { return input_type_.size(); }
size_t GetOutputSize() const { return output_type_.size(); }
const OutputInputRefMap &GetOutInRefMap() const { return out_in_ref_map_; }
private:
std::vector<DataType> input_type_;
std::vector<DataType> output_type_;
// The map between kernel's output and input ref relationship.
OutputInputRefMap out_in_ref_map_;
bool all_same_;
};
} // namespace cpu
} // namespace device
} // namespace mindspore
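With the device-side KernelAttr clone deleted here, kernel selection (SetKernelInfo above) obtains its support lists from kernel::NativeCpuKernelMod::GetCpuSupportedList(op_name). That function's body is not part of this excerpt; a plausible sketch, inferred from its call site and hedged accordingly:

// Hedged sketch only: GetCpuSupportedList is not shown in this diff, and the
// wiring below is inferred from its call site in SetKernelInfo.
std::vector<KernelAttr> NativeCpuKernelMod::GetCpuSupportedList(const std::string &op_name) {
  auto kernel_mod = Factory<NativeCpuKernelMod>::Instance().Create(op_name);
  if (kernel_mod == nullptr) {
    return {};
  }
  // Kernels such as AdamDeltaCpuKernelMod below override GetOpSupport() to
  // publish their dtype/format signatures; kernels without an override would
  // presumably fall back to operator info registered in OpLib.
  return kernel_mod->GetOpSupport();
}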

View File

@@ -19,7 +19,8 @@
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "plugin/device/cpu/hal/device/cpu_memory_manager.h"
#include "plugin/device/cpu/kernel/akg/akg_cpu_kernel_build.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "kernel/kernel_build_info.h"
#include "plugin/device/cpu/hal/device/kernel_select_cpu.h"
#include "utils/trace_base.h"
@@ -220,11 +221,12 @@ void CPUDeviceContext::CreateKernel(const std::vector<CNodePtr> &nodes) const {
}
std::string kernel_name = common::AnfAlgo::GetCNodeName(node);
std::shared_ptr<kernel::NativeCpuKernelMod> cpu_kernel =
kernel::NativeCpuKernelModFactory::GetInstance().Create(kernel_name, node);
kernel::Factory<kernel::NativeCpuKernelMod>::Instance().Create(kernel_name);
if (!cpu_kernel) {
MS_LOG(EXCEPTION) << "Build cpu operator[" << node->fullname_with_scope() << "] failed";
}
cpu_kernel->SetCpuRefMapToKernelInfo(node);
cpu_kernel->Init(node);
AnfAlgo::SetKernelMod(cpu_kernel, node.get());
}

View File

@@ -161,5 +161,7 @@ bool AdamCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, con
}
return true;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Adam, AdamCpuKernelMod);
} // namespace kernel
} // namespace mindspore
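MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Adam, AdamCpuKernelMod) takes over from the header-side MS_REG_CPU_KERNEL(Adam, KernelAttr(), AdamCpuKernelMod) removed in the next file. The macro's definition is not shown in this diff; one common way to implement such a registration macro, offered as an assumption:

// Assumed expansion; the real macro lives in plugin/factory/ms_factory.h.
template <class C>
class KernelRegistrar {
 public:
  KernelRegistrar(const std::string &name, std::function<std::shared_ptr<C>()> creator) {
    // Runs during static initialization, before any kernel is created.
    Factory<C>::Instance().Register(name, std::move(creator));
  }
};

#define MS_KERNEL_FACTORY_REG(BASE, NAME, DERIVED)               \
  static const KernelRegistrar<BASE> g_##NAME##_##DERIVED##_reg( \
    #NAME, []() { return std::make_shared<DERIVED>(); })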

View File

@@ -21,7 +21,7 @@
#include <memory>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
@@ -37,15 +37,11 @@ class AdamCpuKernelMod : public NativeCpuKernelMod {
private:
template <typename T>
void LaunchAdam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
void LaunchAdamNnacl(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
bool use_nesterov_{false};
TypeId dtype_{kTypeUnknown};
enum input_list_ { VAR, M, V, BETA1_POWER, BETA2_POWER, LR, BETA1, BETA2, EPSILON, GRAD };
};
MS_REG_CPU_KERNEL(Adam, KernelAttr(), AdamCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@@ -149,5 +149,22 @@ bool AdamDeltaCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs
LaunchAdamDelta<float>(delta, m, v, lr, beta1, beta2, epsilon, grad, lens);
return true;
}
std::vector<KernelAttr> AdamDeltaCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> kernel_attr_list = {KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)};
return kernel_attr_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, AdamNoUpdateParam, AdamDeltaCpuKernelMod);
} // namespace kernel
} // namespace mindspore
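The file above shows the commit's full replacement for attr-bearing registration macros: the kernel publishes its signatures from a GetOpSupport() override and registers only a name-to-creator mapping. As a porting template, a hypothetical new-style kernel would look like this (FooCpuKernelMod and its signature are illustrative):

// Hypothetical skeleton of a kernel ported to the new factory style.
class FooCpuKernelMod : public NativeCpuKernelMod {
 public:
  void InitKernel(const CNodePtr &kernel_node) override { /* read shapes/attrs */ }
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override { return true; }

 protected:
  // Replaces the dtype/format lists that used to live in MS_REG_CPU_KERNEL.
  std::vector<KernelAttr> GetOpSupport() override {
    static std::vector<KernelAttr> support_list = {
      KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)};
    return support_list;
  }
};
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Foo, FooCpuKernelMod);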

View File

@@ -20,7 +20,7 @@
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
@@ -32,6 +32,9 @@ class AdamDeltaCpuKernelMod : public NativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
void CheckParams(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) const;
@@ -42,20 +45,6 @@ class AdamDeltaCpuKernelMod : public NativeCpuKernelMod {
size_t elem_num_{0};
TypeId dtype_{kTypeUnknown};
};
MS_REG_CPU_KERNEL(AdamNoUpdateParam,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
AdamDeltaCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@@ -144,5 +144,7 @@ bool AdamWeightDecayCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &
}
return true;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, AdamWeightDecay, AdamWeightDecayCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@@ -21,7 +21,7 @@
#include <memory>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
@@ -41,8 +41,6 @@ class AdamWeightDecayCpuKernelMod : public NativeCpuKernelMod {
TypeId dtype_{kTypeUnknown};
enum input_list_ { VAR, M, V, LR, BETA1, BETA2, EPSILON, DECAY, GRAD };
};
MS_REG_CPU_KERNEL(AdamWeightDecay, KernelAttr(), AdamWeightDecayCpuKernelMod)
} // namespace kernel
} // namespace mindspore

View File

@@ -45,5 +45,12 @@ bool AllGatherCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs
auto input_data_num = inputs[0]->size / sizeof(float);
return MPIAllGather(input_addr, output_addr, ranks_group_, input_data_num);
}
std::vector<KernelAttr> AllGatherCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> kernel_attr_list = {
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)};
return kernel_attr_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, _HostAllGather, AllGatherCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@@ -21,7 +21,7 @@
#include <memory>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
@@ -35,12 +35,12 @@ class AllGatherCpuKernelMod : public NativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
std::vector<int> ranks_group_;
};
MS_REG_CPU_KERNEL(_HostAllGather, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
AllGatherCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@@ -116,5 +116,25 @@ void ApplyAdagradCpuKernelMod::LaunchApplyAdagrad(T *var, T *accum, const T *lr,
var[i] -= lr[0] * gradient[i] * (one / sqrt(accum[i] + eps));
}
}
std::vector<KernelAttr> ApplyAdagradCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> kernel_attr_list = {KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)};
return kernel_attr_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, ApplyAdagrad, ApplyAdagradCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@@ -21,7 +21,7 @@
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
@@ -35,6 +35,9 @@ class ApplyAdagradCpuKernelMod : public NativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
void CheckParam(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) const;
@@ -47,26 +50,6 @@ class ApplyAdagradCpuKernelMod : public NativeCpuKernelMod {
bool update_slots_{true};
TypeId dtype_{kTypeUnknown};
};
MS_REG_CPU_KERNEL(ApplyAdagrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
ApplyAdagradCpuKernelMod);
MS_REG_CPU_KERNEL(ApplyAdagrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
ApplyAdagradCpuKernelMod);
} // namespace kernel
} // namespace mindspore
#endif

View File

@@ -57,5 +57,28 @@ bool ApplyMomentumCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &in
}
return true;
}
std::vector<KernelAttr> ApplyMomentumCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> kernel_attr_list = {KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutInRef(0, 0),
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutInRef(0, 0)};
return kernel_attr_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, ApplyMomentum, ApplyMomentumCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@@ -21,7 +21,7 @@
#include <memory>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
@@ -34,29 +34,10 @@ class ApplyMomentumCpuKernelMod : public NativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
};
MS_REG_CPU_KERNEL(ApplyMomentum,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutInRef(0, 0),
ApplyMomentumCpuKernelMod);
MS_REG_CPU_KERNEL(ApplyMomentum,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutInRef(0, 0),
ApplyMomentumCpuKernelMod);
protected:
std::vector<KernelAttr> GetOpSupport() override;
};
} // namespace kernel
} // namespace mindspore

View File

@@ -57,37 +57,9 @@ bool check_validation(const std::vector<size_t> &shape, const size_t num_before_
} // namespace
template <typename T>
void ArgmaxCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
size_t shape_len = shape_.size();
if (shape_len == 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of 'input_x' should be at least 1, but got 0.";
}
int64_t axis = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
axis += SizeToLong(shape_len);
if (axis < 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' should be in range [-1, " << (shape_len - 1)
<< "], but got " << axis;
}
axis = axis % SizeToLong(shape_len);
num_before_axis_ = 1;
num_after_axis_ = 1;
for (size_t i = 0; i < shape_len; i++) {
if (SizeToLong(i) < axis) {
num_before_axis_ *= shape_[i];
} else if (SizeToLong(i) > axis) {
num_after_axis_ *= shape_[i];
}
}
dim_axis_ = shape_[LongToSize(axis)];
}
template <typename T>
bool ArgmaxCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
bool ArgmaxCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
if (!check_validation<T>(shape_, num_before_axis_, num_after_axis_, inputs, outputs)) {
return false;
}
@@ -112,5 +84,50 @@ bool ArgmaxCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs
}
return true;
}
void ArgmaxCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
size_t shape_len = shape_.size();
if (shape_len == 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of 'input_x' should be at least 1, but got 0.";
}
int64_t axis = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
axis += SizeToLong(shape_len);
if (axis < 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' should be in range [-1, " << (shape_len - 1)
<< "], but got " << axis;
}
axis = axis % SizeToLong(shape_len);
num_before_axis_ = 1;
num_after_axis_ = 1;
for (size_t i = 0; i < shape_len; i++) {
if (SizeToLong(i) < axis) {
num_before_axis_ *= shape_[i];
} else if (SizeToLong(i) > axis) {
num_after_axis_ *= shape_[i];
}
}
dim_axis_ = shape_[LongToSize(axis)];
auto build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
if (build_info->GetInputNum() < 1) {
MS_LOG(EXCEPTION) << "Argmax input size should not less than 1!";
}
auto input_type_id = build_info->GetInputDeviceType(0);
switch (input_type_id) {
case kNumberTypeFloat32:
kernel_func_ = &ArgmaxCpuKernelMod::LaunchKernel<float>;
break;
case kNumberTypeFloat16:
kernel_func_ = &ArgmaxCpuKernelMod::LaunchKernel<float16>;
break;
default:
MS_LOG(EXCEPTION) << "Argmax kernel does not support " << TypeIdToString(input_type_id);
}
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Argmax, ArgmaxCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@@ -21,11 +21,10 @@
#include <memory>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
class ArgmaxCpuKernelMod : public NativeCpuKernelMod {
public:
ArgmaxCpuKernelMod() = default;
@@ -34,17 +33,24 @@ class ArgmaxCpuKernelMod : public NativeCpuKernelMod {
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, workspace, outputs);
}
private:
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
using ArgmaxFunc =
std::function<bool(ArgmaxCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
ArgmaxFunc kernel_func_;
std::vector<size_t> shape_;
size_t num_before_axis_{0};
size_t num_after_axis_{0};
size_t dim_axis_{0};
};
MS_REG_CPU_KERNEL_T(Argmax, KernelAttr(), ArgmaxCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(Argmax, KernelAttr(), ArgmaxCpuKernelMod, float16);
} // namespace kernel
} // namespace mindspore
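ArgmaxCpuKernelMod is the first of several kernels in this commit that drop their class template parameter: InitKernel binds a templated LaunchKernel<T> into a std::function member once, and the untyped Launch just forwards. The pattern in isolation, as a self-contained sketch with illustrative names:

// Standalone sketch of the kernel_func_ dispatch pattern used above.
#include <functional>
#include <iostream>
#include <vector>

class Kernel {
 public:
  // Choose the typed implementation once, mirroring InitKernel above.
  void Init(bool use_float) {
    kernel_func_ = use_float ? &Kernel::LaunchKernel<float> : &Kernel::LaunchKernel<int>;
  }
  // The entry point stays untyped and simply forwards to the bound function.
  bool Launch(const std::vector<char> &raw) { return kernel_func_(this, raw); }

 private:
  template <typename T>
  bool LaunchKernel(const std::vector<char> &raw) {
    // The buffer is reinterpreted with the dtype selected in Init().
    const T *data = reinterpret_cast<const T *>(raw.data());
    std::cout << "first element: " << data[0] << '\n';
    return true;
  }
  std::function<bool(Kernel *, const std::vector<char> &)> kernel_func_;
};

int main() {
  Kernel k;
  k.Init(true);
  std::vector<char> buf(sizeof(float), 0);
  *reinterpret_cast<float *>(buf.data()) = 3.5f;
  return k.Launch(buf) ? 0 : 1;  // prints "first element: 3.5"
}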

View File

@@ -61,34 +61,9 @@ bool check_validation(const std::vector<size_t> &shape, const size_t num_before_
} // namespace
template <typename T>
void ArgMaxWithValueCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
size_t shape_len = shape_.size();
int64_t axis = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
axis += SizeToLong(shape_len);
if (axis < 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' should be in range [-1, " << (shape_len - 1)
<< "], but got " << axis;
}
axis = axis % SizeToLong(shape_len);
num_before_axis_ = 1;
num_after_axis_ = 1;
for (size_t i = 0; i < shape_len; i++) {
if (SizeToLong(i) < axis) {
num_before_axis_ *= shape_[i];
} else if (SizeToLong(i) > axis) {
num_after_axis_ *= shape_[i];
}
}
dim_axis_ = shape_[LongToSize(axis)];
}
template <typename T>
bool ArgMaxWithValueCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
bool ArgMaxWithValueCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
if (!check_validation<T>(shape_, num_before_axis_, num_after_axis_, inputs, outputs)) {
return false;
}
@@ -116,5 +91,47 @@ bool ArgMaxWithValueCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr
}
return true;
}
void ArgMaxWithValueCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
size_t shape_len = shape_.size();
int64_t axis = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
axis += SizeToLong(shape_len);
if (axis < 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' should be in range [-1, " << (shape_len - 1)
<< "], but got " << axis;
}
axis = axis % SizeToLong(shape_len);
num_before_axis_ = 1;
num_after_axis_ = 1;
for (size_t i = 0; i < shape_len; i++) {
if (SizeToLong(i) < axis) {
num_before_axis_ *= shape_[i];
} else if (SizeToLong(i) > axis) {
num_after_axis_ *= shape_[i];
}
}
dim_axis_ = shape_[LongToSize(axis)];
auto build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
if (build_info->GetInputNum() < 1) {
MS_LOG(EXCEPTION) << "Argmax input size should not less than 1!";
}
auto input_type_id = build_info->GetInputDeviceType(0);
switch (input_type_id) {
case kNumberTypeFloat32:
kernel_func_ = &ArgMaxWithValueCpuKernelMod::LaunchKernel<float>;
break;
case kNumberTypeFloat16:
kernel_func_ = &ArgMaxWithValueCpuKernelMod::LaunchKernel<float16>;
break;
default:
MS_LOG(EXCEPTION) << "Argmax kernel does not support " << TypeIdToString(input_type_id);
}
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, ArgMaxWithValue, ArgMaxWithValueCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@@ -23,11 +23,10 @@
#include <algorithm>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
class ArgMaxWithValueCpuKernelMod : public NativeCpuKernelMod {
public:
ArgMaxWithValueCpuKernelMod() = default;
@@ -35,18 +34,25 @@ class ArgMaxWithValueCpuKernelMod : public NativeCpuKernelMod {
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, workspace, outputs);
}
private:
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
using ArgMaxWithValueFunc =
std::function<bool(ArgMaxWithValueCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
ArgMaxWithValueFunc kernel_func_;
std::vector<size_t> shape_;
size_t num_before_axis_{0};
size_t num_after_axis_{0};
size_t dim_axis_{0};
};
MS_REG_CPU_KERNEL_T(ArgMaxWithValue, KernelAttr(), ArgMaxWithValueCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(ArgMaxWithValue, KernelAttr(), ArgMaxWithValueCpuKernelMod, float16);
} // namespace kernel
} // namespace mindspore

View File

@@ -62,37 +62,9 @@ bool check_validation(const std::vector<size_t> &shape, const size_t num_before_
} // namespace
template <typename T>
void ArgMinWithValueCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
size_t shape_len = shape_.size();
if (shape_len == 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of 'input_x' should be at least 1, but got 0.";
}
int64_t axis = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
axis += static_cast<int64_t>(shape_len);
if (axis < 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' should be in range [-1, " << (shape_len - 1)
<< "], but got " << axis;
}
axis = axis % static_cast<int64_t>(shape_len);
num_before_axis_ = 1;
num_after_axis_ = 1;
for (size_t i = 0; i < shape_len; i++) {
if (static_cast<int64_t>(i) < axis) {
num_before_axis_ *= shape_[i];
} else if (static_cast<int64_t>(i) > axis) {
num_after_axis_ *= shape_[i];
}
}
dim_axis_ = shape_[axis];
}
template <typename T>
bool ArgMinWithValueCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
bool ArgMinWithValueCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
if (!check_validation<T>(shape_, num_before_axis_, num_after_axis_, inputs, outputs)) {
return false;
}
@@ -119,5 +91,50 @@ bool ArgMinWithValueCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr
}
return true;
}
void ArgMinWithValueCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
size_t shape_len = shape_.size();
if (shape_len == 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of 'input_x' should be at least 1, but got 0.";
}
int64_t axis = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
axis += static_cast<int64_t>(shape_len);
if (axis < 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' should be in range [-1, " << (shape_len - 1)
<< "], but got " << axis;
}
axis = axis % static_cast<int64_t>(shape_len);
num_before_axis_ = 1;
num_after_axis_ = 1;
for (size_t i = 0; i < shape_len; i++) {
if (static_cast<int64_t>(i) < axis) {
num_before_axis_ *= shape_[i];
} else if (static_cast<int64_t>(i) > axis) {
num_after_axis_ *= shape_[i];
}
}
dim_axis_ = shape_[axis];
auto build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
if (build_info->GetInputNum() < 1) {
MS_LOG(EXCEPTION) << "Argmax input size should not less than 1!";
}
auto input_type_id = build_info->GetInputDeviceType(0);
switch (input_type_id) {
case kNumberTypeFloat32:
kernel_func_ = &ArgMinWithValueCpuKernelMod::LaunchKernel<float>;
break;
case kNumberTypeFloat16:
kernel_func_ = &ArgMinWithValueCpuKernelMod::LaunchKernel<float16>;
break;
default:
MS_LOG(EXCEPTION) << "Argmax kernel does not support " << TypeIdToString(input_type_id);
}
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, ArgMinWithValue, ArgMinWithValueCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@@ -23,11 +23,10 @@
#include <algorithm>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
class ArgMinWithValueCpuKernelMod : public NativeCpuKernelMod {
public:
ArgMinWithValueCpuKernelMod() = default;
@@ -36,17 +35,24 @@ class ArgMinWithValueCpuKernelMod : public NativeCpuKernelMod {
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, workspace, outputs);
}
private:
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
using ArgMinWithValueFunc =
std::function<bool(ArgMinWithValueCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
ArgMinWithValueFunc kernel_func_;
std::vector<size_t> shape_;
size_t num_before_axis_{0};
size_t num_after_axis_{0};
size_t dim_axis_{0};
};
MS_REG_CPU_KERNEL_T(ArgMinWithValue, KernelAttr(), ArgMinWithValueCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(ArgMinWithValue, KernelAttr(), ArgMinWithValueCpuKernelMod, float16);
} // namespace kernel
} // namespace mindspore

View File

@@ -20,6 +20,10 @@
#include <string>
#include <unordered_map>
#include <limits>
#include <map>
#include <algorithm>
#include <utility>
#include <memory>
#include "plugin/device/cpu/kernel/nnacl/fp32/power_fp32.h"
#include "plugin/device/cpu/kernel/nnacl/fp32/sub_fp32.h"
@@ -29,11 +33,14 @@
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kInputsNum = 2;
constexpr size_t kOutputsNum = 1;
constexpr float kMaxSubSerialSize = 10000.0;
constexpr float kMaxPowSerialSize = 700.0;
constexpr auto kSub = "Sub";
constexpr auto kMul = "Mul";
constexpr auto kRealDiv = "RealDiv";
constexpr auto kAssignAdd = "AssignAdd";
template <typename T>
void ElementRealDiv(const T *input1, const T *input2, T *out, size_t size, size_t delta_1, size_t delta_2) {
size_t idx_1 = 0;
@@ -59,10 +66,123 @@ void ElementRealDiv(const T *input1, const T *input2, T *out, size_t size, size_
out[i] = dividend / divisor;
}
}
} // namespace
template <typename T>
void ArithmeticCpuKernelMod<T>::AssignAdd(T *input1, const T *input2, T *out) {
class ArithmeticCpuTypeFunc : public CpuKernelFunc {
public:
~ArithmeticCpuTypeFunc() override = default;
explicit ArithmeticCpuTypeFunc(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
input_shape1_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
input_shape2_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
if (output_shape_.size() == 0) {
(void)output_shape_.insert(output_shape_.begin(), 1);
}
output_size_ = 1;
for (size_t i = 0; i < output_shape_.size(); ++i) {
output_size_ *= output_shape_[i];
}
op_para_.in_elements_num0_ = 1;
for (size_t i = 0; i < input_shape1_.size(); ++i) {
op_para_.in_elements_num0_ *= input_shape1_[i];
}
op_para_.in_elements_num1_ = 1;
for (size_t i = 0; i < input_shape2_.size(); ++i) {
op_para_.in_elements_num1_ *= input_shape2_[i];
}
size_t l = input_shape1_.size();
for (size_t i = 0; i < output_shape_.size() - l; ++i) {
(void)input_shape1_.insert(input_shape1_.begin(), 1);
}
l = input_shape2_.size();
for (size_t i = 0; i < output_shape_.size() - l; ++i) {
(void)input_shape2_.insert(input_shape2_.begin(), 1);
}
CPUKernelUtils::GetElementNumEveryDim(input_shape1_, &input_element_num1_);
CPUKernelUtils::GetElementNumEveryDim(input_shape2_, &input_element_num2_);
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
InitComputeFunc();
}
void Sub(const T *input1, const T *input2, T *out);
void Add(const T *input1, const T *input2, T *out);
void Mul(const T *input1, const T *input2, T *out);
void RealDiv(const T *input1, const T *input2, T *out);
void Div(const T *input1, const T *input2, T *out);
void FloorDiv(const T *input1, const T *input2, T *out);
void Mod(const T *input1, const T *input2, T *out);
void FloorMod(const T *input1, const T *input2, T *out);
void Pow(const T *input1, const T *input2, T *out);
void AssignAdd(T *input1, const T *input2, T *out);
void Atan2(const T *input1, const T *input2, T *out);
void SquaredDifference(const T *input1, const T *input2, T *out);
bool RunFunc(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override {
auto *input1 = reinterpret_cast<T *>(inputs[0]->addr);
const auto *input2 = reinterpret_cast<T *>(inputs[1]->addr);
auto *output = reinterpret_cast<T *>(outputs[0]->addr);
if (output_size_ == 0) {
MS_LOG(WARNING) << kernel_name_ << " output shape contains 0, output_shape: " << output_shape_;
return true;
}
if (kernel_name_ == prim::kPrimAssignAdd->name()) {
AssignAdd(input1, input2, output);
} else {
compute_func_(this, input1, input2, output);
}
return true;
}
private:
void InitComputeFunc() {
if (kernel_name_ == prim::kPrimAssignAdd->name()) {
return;
}
static const std::unordered_map<std::string, TypeComputeFunc> arithmeticMathFuncMap{
{prim::kPrimAdd->name(), &ArithmeticCpuTypeFunc<T>::Add},
{prim::kPrimSub->name(), &ArithmeticCpuTypeFunc<T>::Sub},
{prim::kPrimMul->name(), &ArithmeticCpuTypeFunc<T>::Mul},
{prim::kPrimDiv->name(), &ArithmeticCpuTypeFunc<T>::Div},
{prim::kPrimMod->name(), &ArithmeticCpuTypeFunc<T>::Mod},
{prim::kPrimFloorMod->name(), &ArithmeticCpuTypeFunc<T>::FloorMod},
{prim::kPrimPow->name(), &ArithmeticCpuTypeFunc<T>::Pow},
{prim::kPrimFloorDiv->name(), &ArithmeticCpuTypeFunc<T>::FloorDiv},
{prim::kPrimAtan2->name(), &ArithmeticCpuTypeFunc<T>::Atan2},
{prim::kPrimRealDiv->name(), &ArithmeticCpuTypeFunc<T>::RealDiv},
{prim::kPrimSquaredDifference->name(), &ArithmeticCpuTypeFunc<T>::SquaredDifference}};
if (arithmeticMathFuncMap.find(kernel_name_) == arithmeticMathFuncMap.end()) {
MS_LOG(EXCEPTION) << "For 'Arithmetic', only supports operators in " << Unorderedmap2Str(arithmeticMathFuncMap)
<< ", but got " << kernel_name_;
}
compute_func_ = arithmeticMathFuncMap.at(kernel_name_);
}
std::string kernel_name_;
size_t output_size_{1};
ArithmeticParameter op_para_{};
std::vector<size_t> input_shape1_;
std::vector<size_t> input_shape2_;
std::vector<size_t> input_element_num1_;
std::vector<size_t> input_element_num2_;
std::vector<size_t> output_shape_;
std::vector<size_t> output_element_num_;
using TypeComputeFunc = std::function<void(ArithmeticCpuTypeFunc *, const T *in_x, const T *in_y, T *out)>;
TypeComputeFunc compute_func_{nullptr};
};
template <typename T>
void ArithmeticCpuTypeFunc<T>::AssignAdd(T *input1, const T *input2, T *out) {
auto task = [&input1, &input2, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = input1[i] + input2[i];
@@ -73,7 +193,7 @@ void ArithmeticCpuKernelMod<T>::AssignAdd(T *input1, const T *input2, T *out) {
}
template <typename T>
void ArithmeticCpuKernelMod<T>::Add(const T *input1, const T *input2, T *out) {
void ArithmeticCpuTypeFunc<T>::Add(const T *input1, const T *input2, T *out) {
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
auto iter = base_iter;
@@ -87,7 +207,7 @@ void ArithmeticCpuKernelMod<T>::Add(const T *input1, const T *input2, T *out) {
}
template <typename T>
void ArithmeticCpuKernelMod<T>::Sub(const T *input1, const T *input2, T *out) {
void ArithmeticCpuTypeFunc<T>::Sub(const T *input1, const T *input2, T *out) {
if constexpr (std::is_same_v<T, float>) {
if (input_shape1_ == input_shape2_) {
auto task = [this, input1, input2, out](size_t start, size_t end) {
@@ -122,7 +242,7 @@ void ArithmeticCpuKernelMod<T>::Sub(const T *input1, const T *input2, T *out) {
}
template <typename T>
void ArithmeticCpuKernelMod<T>::Mul(const T *input1, const T *input2, T *out) {
void ArithmeticCpuTypeFunc<T>::Mul(const T *input1, const T *input2, T *out) {
if constexpr (std::is_same_v<T, float>) {
if (input_shape1_ == input_shape2_) {
auto task = [this, input1, input2, out](size_t start, size_t end) {
@@ -160,7 +280,7 @@ void ArithmeticCpuKernelMod<T>::Mul(const T *input1, const T *input2, T *out) {
}
template <typename T>
void ArithmeticCpuKernelMod<T>::RealDiv(const T *input1, const T *input2, T *out) {
void ArithmeticCpuTypeFunc<T>::RealDiv(const T *input1, const T *input2, T *out) {
if (input_shape1_ == input_shape2_) {
auto task = [&](size_t start, size_t end) {
ElementRealDiv<T>(input1 + start, input2 + start, out + start, end - start, 1, 1);
@@ -211,7 +331,7 @@ void ArithmeticCpuKernelMod<T>::RealDiv(const T *input1, const T *input2, T *out
}
template <typename T>
void ArithmeticCpuKernelMod<T>::Div(const T *input1, const T *input2, T *out) {
void ArithmeticCpuTypeFunc<T>::Div(const T *input1, const T *input2, T *out) {
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
auto iter = base_iter;
@@ -240,7 +360,7 @@ void ArithmeticCpuKernelMod<T>::Div(const T *input1, const T *input2, T *out) {
}
template <typename T>
void ArithmeticCpuKernelMod<T>::FloorDiv(const T *input1, const T *input2, T *out) {
void ArithmeticCpuTypeFunc<T>::FloorDiv(const T *input1, const T *input2, T *out) {
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
auto iter = base_iter;
@ -269,7 +389,7 @@ void ArithmeticCpuKernelMod<T>::FloorDiv(const T *input1, const T *input2, T *ou
}
template <typename T>
void ArithmeticCpuKernelMod<T>::Mod(const T *input1, const T *input2, T *out) {
void ArithmeticCpuTypeFunc<T>::Mod(const T *input1, const T *input2, T *out) {
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
auto iter = base_iter;
@ -291,7 +411,7 @@ void ArithmeticCpuKernelMod<T>::Mod(const T *input1, const T *input2, T *out) {
}
template <typename T>
void ArithmeticCpuKernelMod<T>::FloorMod(const T *input1, const T *input2, T *out) {
void ArithmeticCpuTypeFunc<T>::FloorMod(const T *input1, const T *input2, T *out) {
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
auto iter = base_iter;
@ -308,7 +428,7 @@ void ArithmeticCpuKernelMod<T>::FloorMod(const T *input1, const T *input2, T *ou
}
template <typename T>
void ArithmeticCpuKernelMod<T>::Pow(const T *input1, const T *input2, T *out) {
void ArithmeticCpuTypeFunc<T>::Pow(const T *input1, const T *input2, T *out) {
if constexpr (std::is_same_v<T, float>) {
auto is_power_single = [this]() {
bool is_power_single = false;
@ -365,7 +485,7 @@ void ArithmeticCpuKernelMod<T>::Pow(const T *input1, const T *input2, T *out) {
}
template <typename T>
void ArithmeticCpuKernelMod<T>::SquaredDifference(const T *input1, const T *input2, T *out) {
void ArithmeticCpuTypeFunc<T>::SquaredDifference(const T *input1, const T *input2, T *out) {
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
auto iter = base_iter;
@ -384,7 +504,7 @@ void ArithmeticCpuKernelMod<T>::SquaredDifference(const T *input1, const T *inpu
}
template <typename T>
void ArithmeticCpuKernelMod<T>::Atan2(const T *input1, const T *input2, T *out) {
void ArithmeticCpuTypeFunc<T>::Atan2(const T *input1, const T *input2, T *out) {
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
auto task = [&input1, &input2, &out, &base_iter](size_t start, size_t end) {
auto iter = base_iter;
@ -399,87 +519,172 @@ void ArithmeticCpuKernelMod<T>::Atan2(const T *input1, const T *input2, T *out)
}
template <typename T>
void ArithmeticCpuKernelMod<T>::InitComputeFunc() {
if (kernel_name_ == prim::kPrimAssignAdd->name()) {
return;
}
static const std::unordered_map<std::string, TypeComputeFunc> arithmeticMathFuncMap{
{prim::kPrimAdd->name(), &ArithmeticCpuKernelMod<T>::Add},
{prim::kPrimSub->name(), &ArithmeticCpuKernelMod<T>::Sub},
{prim::kPrimMul->name(), &ArithmeticCpuKernelMod<T>::Mul},
{prim::kPrimDiv->name(), &ArithmeticCpuKernelMod<T>::Div},
{prim::kPrimMod->name(), &ArithmeticCpuKernelMod<T>::Mod},
{prim::kPrimFloorMod->name(), &ArithmeticCpuKernelMod<T>::FloorMod},
{prim::kPrimPow->name(), &ArithmeticCpuKernelMod<T>::Pow},
{prim::kPrimFloorDiv->name(), &ArithmeticCpuKernelMod<T>::FloorDiv},
{prim::kPrimAtan2->name(), &ArithmeticCpuKernelMod<T>::Atan2},
{prim::kPrimRealDiv->name(), &ArithmeticCpuKernelMod<T>::RealDiv},
{prim::kPrimSquaredDifference->name(), &ArithmeticCpuKernelMod<T>::SquaredDifference}};
if (arithmeticMathFuncMap.find(kernel_name_) == arithmeticMathFuncMap.end()) {
MS_LOG(EXCEPTION) << "For 'Arithmetic', only supports operators in " << Unorderedmap2Str(arithmeticMathFuncMap)
<< ", but got " << kernel_name_;
}
compute_func_ = arithmeticMathFuncMap.at(kernel_name_);
std::shared_ptr<CpuKernelFunc> SpecializeArithFunc(const CNodePtr &kernel_node) {
return std::make_shared<ArithmeticCpuTypeFunc<T>>(kernel_node);
}
using ArithmeticCpuFuncCreator = std::function<std::shared_ptr<CpuKernelFunc>(const CNodePtr &)>;
static std::map<std::string, std::vector<std::pair<KernelAttr, ArithmeticCpuFuncCreator>>> kernel_attr_list = {
{kSub,
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SpecializeArithFunc<int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SpecializeArithFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
SpecializeArithFunc<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
SpecializeArithFunc<double>}}},
{kMul,
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SpecializeArithFunc<int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SpecializeArithFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
SpecializeArithFunc<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
SpecializeArithFunc<double>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
SpecializeArithFunc<uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
SpecializeArithFunc<bool>}}},
{prim::kPrimDiv->name(),
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SpecializeArithFunc<int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SpecializeArithFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
SpecializeArithFunc<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
SpecializeArithFunc<double>}}},
{prim::kPrimPow->name(),
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SpecializeArithFunc<int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SpecializeArithFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
SpecializeArithFunc<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
SpecializeArithFunc<double>}}},
{kRealDiv,
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SpecializeArithFunc<int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
SpecializeArithFunc<float16>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SpecializeArithFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
SpecializeArithFunc<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
SpecializeArithFunc<double>}}},
{prim::kPrimFloorDiv->name(),
{{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
SpecializeArithFunc<int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SpecializeArithFunc<int>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
SpecializeArithFunc<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
SpecializeArithFunc<uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
SpecializeArithFunc<uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
SpecializeArithFunc<float16>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SpecializeArithFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
SpecializeArithFunc<double>}}},
{prim::kPrimMod->name(),
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SpecializeArithFunc<int>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SpecializeArithFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
SpecializeArithFunc<int64_t>}}},
{prim::kPrimFloorMod->name(),
{{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
SpecializeArithFunc<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SpecializeArithFunc<int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SpecializeArithFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
SpecializeArithFunc<float16>}}},
{kAssignAdd,
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SpecializeArithFunc<int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SpecializeArithFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
SpecializeArithFunc<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
SpecializeArithFunc<double>}}},
{prim::kPrimSquaredDifference->name(),
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SpecializeArithFunc<int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
SpecializeArithFunc<float16>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SpecializeArithFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
SpecializeArithFunc<double>}}},
{prim::kPrimAtan2->name(),
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SpecializeArithFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
SpecializeArithFunc<double>}}}};
} // namespace
template <typename T>
void ArithmeticCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
void ArithmeticCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
input_shape1_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
input_shape2_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
if (output_shape_.size() == 0) {
(void)output_shape_.insert(output_shape_.begin(), 1);
if (kernel_name_ != kernel_type_) {
MS_LOG(EXCEPTION) << "Need to be " << kernel_type_ << " but got kernel name as " << kernel_name_;
}
output_size_ = 1;
for (size_t i = 0; i < output_shape_.size(); ++i) {
output_size_ *= output_shape_[i];
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "Arithmetic does not support this kernel data type: " << kernel_attr;
}
op_para_.in_elements_num0_ = 1;
for (size_t i = 0; i < input_shape1_.size(); ++i) {
op_para_.in_elements_num0_ *= input_shape1_[i];
}
op_para_.in_elements_num1_ = 1;
for (size_t i = 0; i < input_shape2_.size(); ++i) {
op_para_.in_elements_num1_ *= input_shape2_[i];
}
size_t l = input_shape1_.size();
for (size_t i = 0; i < output_shape_.size() - l; ++i) {
(void)input_shape1_.insert(input_shape1_.begin(), 1);
}
l = input_shape2_.size();
for (size_t i = 0; i < output_shape_.size() - l; ++i) {
(void)input_shape2_.insert(input_shape2_.begin(), 1);
}
CPUKernelUtils::GetElementNumEveryDim(input_shape1_, &input_element_num1_);
CPUKernelUtils::GetElementNumEveryDim(input_shape2_, &input_element_num2_);
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
InitComputeFunc();
func_obj_ = kernel_attr_list[kernel_name_][index].second(kernel_node);
}
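InitKernel now resolves the concrete functor in two steps: MatchKernelAttr compares the node's dtype signature against GetOpSupport and returns the matching index, and that index selects the creator from kernel_attr_list. Below is a rough sketch of the index-based matching, under the simplifying assumption that an attr is just a list of input dtype ids (the real KernelAttr also carries outputs and formats); Attr and Match are illustrative names.

#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

using Attr = std::vector<int>;  // stand-in: the input dtype ids of one supported variant

// Returns whether a variant matches and, if so, which one; the index is
// what selects the creator in the table.
std::pair<bool, std::size_t> Match(const Attr &want, const std::vector<Attr> &supported) {
  for (std::size_t i = 0; i < supported.size(); ++i) {
    if (supported[i] == want) return {true, i};
  }
  return {false, 0};
}

int main() {
  std::vector<Attr> supported = {{32, 32}, {64, 64}};  // e.g. (int32,int32) and (int64,int64)
  auto [ok, idx] = Match({64, 64}, supported);
  std::printf("matched=%d index=%zu\n", ok, idx);  // matched=1 index=1
}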
template <typename T>
bool ArithmeticCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
auto *input1 = reinterpret_cast<T *>(inputs[0]->addr);
const auto *input2 = reinterpret_cast<T *>(inputs[1]->addr);
auto *output = reinterpret_cast<T *>(outputs[0]->addr);
if (output_size_ == 0) {
MS_LOG(WARNING) << kernel_name_ << " output shape contains 0, output_shape: " << output_shape_;
return true;
std::vector<KernelAttr> ArithmeticCpuKernelMod::GetOpSupport() {
auto iter = kernel_attr_list.find(kernel_type_);
if (iter == kernel_attr_list.end()) {
MS_LOG(EXCEPTION) << "Arithmetic cpu does not support " << kernel_type_;
}
if (kernel_name_ == prim::kPrimAssignAdd->name()) {
AssignAdd(input1, input2, output);
} else {
compute_func_(this, input1, input2, output);
}
return true;
std::vector<KernelAttr> support_list;
std::transform(iter->second.begin(), iter->second.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, ArithmeticCpuFuncCreator> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Sub,
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kSub); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Mul,
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kMul); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Div,
[]() { return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimDiv->name()); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Pow,
[]() { return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimPow->name()); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, RealDiv,
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kRealDiv); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, FloorDiv, []() {
return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimFloorDiv->name());
});
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Mod,
[]() { return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimMod->name()); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, FloorMod, []() {
return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimFloorMod->name());
});
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, AssignAdd,
[]() { return std::make_shared<ArithmeticCpuKernelMod>(kAssignAdd); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, SquaredDifference, []() {
return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimSquaredDifference->name());
});
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Atan2,
[]() { return std::make_shared<ArithmeticCpuKernelMod>(prim::kPrimAtan2->name()); });
} // namespace kernel
} // namespace mindspore
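The MS_KERNEL_FACTORY_REG_BY_CREATOR calls replace the per-type MS_REG_CPU_KERNEL_T macros: each op name now maps to a creator lambda that constructs one non-templated ArithmeticCpuKernelMod with the right kernel_type_. The following is an illustrative stand-alone sketch of that registration-by-creator idea, not the actual MindSpore Factory class; KernelMod and Factory here are placeholder names.

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Placeholder for a kernel mod configured by name, not by template type.
struct KernelMod {
  explicit KernelMod(std::string type) : type_(std::move(type)) {}
  std::string type_;
};

class Factory {
 public:
  static Factory &Instance() {
    static Factory factory;
    return factory;
  }
  void Register(const std::string &name, std::function<std::shared_ptr<KernelMod>()> creator) {
    creators_[name] = std::move(creator);
  }
  std::shared_ptr<KernelMod> Create(const std::string &name) {
    auto it = creators_.find(name);
    return it == creators_.end() ? nullptr : it->second();
  }

 private:
  std::map<std::string, std::function<std::shared_ptr<KernelMod>()>> creators_;
};

int main() {
  Factory::Instance().Register("Sub", [] { return std::make_shared<KernelMod>("Sub"); });
  auto mod = Factory::Instance().Create("Sub");
  std::cout << (mod ? mod->type_ : "not registered") << '\n';  // Sub
}

Registering a closure per name is what lets a single class serve Sub, Mul, Div and the rest without a template instantiation at each registration site.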


@ -18,216 +18,39 @@
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ARITHMETIC_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include <string>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/cpu/kernel/nnacl/arithmetic.h"
namespace mindspore {
namespace kernel {
template <typename T>
class ArithmeticCpuKernelMod : public NativeCpuKernelMod {
public:
ArithmeticCpuKernelMod() = default;
explicit ArithmeticCpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
~ArithmeticCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
const std::vector<AddressPtr> &outputs) override {
const size_t kInputsNum = 2;
const size_t kOutputsNum = 1;
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
return func_obj_->RunFunc(inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
void InitComputeFunc();
void Sub(const T *input1, const T *input2, T *out);
void Add(const T *input1, const T *input2, T *out);
void Mul(const T *input1, const T *input2, T *out);
void RealDiv(const T *input1, const T *input2, T *out);
void Div(const T *input1, const T *input2, T *out);
void FloorDiv(const T *input1, const T *input2, T *out);
void Mod(const T *input1, const T *input2, T *out);
void FloorMod(const T *input1, const T *input2, T *out);
void Pow(const T *input1, const T *input2, T *out);
void AssignAdd(T *input1, const T *input2, T *out);
void Atan2(const T *input1, const T *input2, T *out);
void SquaredDifference(const T *input1, const T *input2, T *out);
using TypeComputeFunc = std::function<void(ArithmeticCpuKernelMod *, const T *in_x, const T *in_y, T *out)>;
TypeComputeFunc compute_func_{nullptr};
size_t output_size_{1};
ArithmeticParameter op_para_{};
std::vector<size_t> input_shape1_;
std::vector<size_t> input_shape2_;
std::vector<size_t> input_element_num1_;
std::vector<size_t> input_element_num2_;
std::vector<size_t> output_shape_;
std::vector<size_t> output_element_num_;
std::shared_ptr<CpuKernelFunc> func_obj_;
std::string kernel_type_{"Unknown"};
};
MS_REG_CPU_KERNEL_T(
Sub, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticCpuKernelMod, int32_t);
MS_REG_CPU_KERNEL_T(
Sub, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(
Sub, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticCpuKernelMod, int64_t);
MS_REG_CPU_KERNEL_T(
Sub, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticCpuKernelMod, double);
MS_REG_CPU_KERNEL_T(
Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticCpuKernelMod, int32_t);
MS_REG_CPU_KERNEL_T(
Mul, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(
Mul, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticCpuKernelMod, int64_t);
MS_REG_CPU_KERNEL_T(
Mul, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticCpuKernelMod, double);
MS_REG_CPU_KERNEL_T(
Mul, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
ArithmeticCpuKernelMod, uint8_t);
MS_REG_CPU_KERNEL_T(
Mul, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
ArithmeticCpuKernelMod, bool);
MS_REG_CPU_KERNEL_T(
Div, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticCpuKernelMod, int32_t);
MS_REG_CPU_KERNEL_T(
Div, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(
Div, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticCpuKernelMod, int64_t);
MS_REG_CPU_KERNEL_T(
Div, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticCpuKernelMod, double);
MS_REG_CPU_KERNEL_T(
Pow, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticCpuKernelMod, int32_t);
MS_REG_CPU_KERNEL_T(
Pow, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(
Pow, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticCpuKernelMod, int64_t);
MS_REG_CPU_KERNEL_T(
Pow, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticCpuKernelMod, double);
MS_REG_CPU_KERNEL_T(
RealDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticCpuKernelMod, int32_t);
MS_REG_CPU_KERNEL_T(
RealDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ArithmeticCpuKernelMod, float16);
MS_REG_CPU_KERNEL_T(
RealDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(
RealDiv, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticCpuKernelMod, int64_t);
MS_REG_CPU_KERNEL_T(
RealDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticCpuKernelMod, double);
MS_REG_CPU_KERNEL_T(
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
ArithmeticCpuKernelMod, int8_t);
MS_REG_CPU_KERNEL_T(
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticCpuKernelMod, int);
MS_REG_CPU_KERNEL_T(
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticCpuKernelMod, int64_t);
MS_REG_CPU_KERNEL_T(
FloorDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ArithmeticCpuKernelMod, float16);
MS_REG_CPU_KERNEL_T(
FloorDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticCpuKernelMod, int64_t);
MS_REG_CPU_KERNEL_T(
FloorDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticCpuKernelMod, double);
MS_REG_CPU_KERNEL_T(
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
ArithmeticCpuKernelMod, uint8_t);
MS_REG_CPU_KERNEL_T(
FloorDiv,
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
ArithmeticCpuKernelMod, uint16_t);
MS_REG_CPU_KERNEL_T(
Mod, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticCpuKernelMod, int);
MS_REG_CPU_KERNEL_T(
Mod, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(
Mod, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticCpuKernelMod, int64_t);
MS_REG_CPU_KERNEL_T(
FloorMod, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticCpuKernelMod, int64_t);
MS_REG_CPU_KERNEL_T(
FloorMod, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticCpuKernelMod, int32_t);
MS_REG_CPU_KERNEL_T(
FloorMod,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(
FloorMod,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ArithmeticCpuKernelMod, float16);
MS_REG_CPU_KERNEL_T(
AssignAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticCpuKernelMod, int32_t);
MS_REG_CPU_KERNEL_T(
AssignAdd,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(
AssignAdd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticCpuKernelMod, int64_t);
MS_REG_CPU_KERNEL_T(
AssignAdd,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticCpuKernelMod, double);
MS_REG_CPU_KERNEL_T(
SquaredDifference,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticCpuKernelMod, int);
MS_REG_CPU_KERNEL_T(
SquaredDifference,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ArithmeticCpuKernelMod, float16);
MS_REG_CPU_KERNEL_T(
SquaredDifference,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(
SquaredDifference,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticCpuKernelMod, double);
MS_REG_CPU_KERNEL_T(
Atan2,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(
Atan2,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticCpuKernelMod, double);
} // namespace kernel
} // namespace mindspore


@ -20,6 +20,9 @@
#include <cmath>
#include <unordered_map>
#include <functional>
#include <map>
#include <algorithm>
#include <utility>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
@ -27,13 +30,104 @@ namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kMaxLessSerialSize = 15000;
constexpr size_t kInputsNum = 2;
constexpr size_t kOutputsNum = 1;
} // namespace
constexpr auto kEqual = "Equal";
constexpr auto kNotEqual = "NotEqual";
template <typename T>
class ArithLogicCpuTypeFunc : public CpuKernelFunc {
public:
ArithLogicCpuTypeFunc() = default;
~ArithLogicCpuTypeFunc() override = default;
bool RunFunc(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override {
const auto *input1 = reinterpret_cast<T *>(inputs[0]->addr);
const auto *input2 = reinterpret_cast<T *>(inputs[1]->addr);
bool *output = reinterpret_cast<bool *>(outputs[0]->addr);
compute_func_(this, input1, input2, output);
return true;
}
void InitFunc(const CNodePtr &kernel_node) override {
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
input_shape1_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
input_shape2_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
if (output_shape_.empty()) {
(void)output_shape_.insert(output_shape_.begin(), 1);
}
output_size_ = 1;
for (size_t i = 0; i < output_shape_.size(); ++i) {
output_size_ *= output_shape_[i];
}
size_t l = input_shape1_.size();
for (size_t i = 0; i < output_shape_.size() - l; ++i) {
(void)input_shape1_.insert(input_shape1_.begin(), 1);
}
l = input_shape2_.size();
for (size_t i = 0; i < output_shape_.size() - l; ++i) {
(void)input_shape2_.insert(input_shape2_.begin(), 1);
}
CPUKernelUtils::GetElementNumEveryDim(input_shape1_, &input_element_num1_);
CPUKernelUtils::GetElementNumEveryDim(input_shape2_, &input_element_num2_);
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
auto dtype_1 = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
if (dtype_ != dtype_1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the 'input1' and 'input2' should have the same data type, but got type of 'input1': "
<< dtype_ << ", and the type of 'input2': " << dtype_1;
}
static const std::unordered_map<std::string, TypeComputeFunc> arithmetic_logic_func_map{
{prim::kPrimGreater->name(), &ArithLogicCpuTypeFunc<T>::Greater},
{prim::kPrimGreaterEqual->name(), &ArithLogicCpuTypeFunc<T>::GreaterEqual},
{prim::kPrimLogicalAnd->name(), &ArithLogicCpuTypeFunc<T>::LogicalAnd},
{prim::kPrimLessEqual->name(), &ArithLogicCpuTypeFunc<T>::LessEqual},
{prim::kPrimLogicalOr->name(), &ArithLogicCpuTypeFunc<T>::LogicalOr},
{prim::kPrimLogicalXor->name(), &ArithLogicCpuTypeFunc<T>::LogicalXor},
{prim::kPrimLess->name(), &ArithLogicCpuTypeFunc<T>::Less},
{prim::kPrimNotEqual->name(), &ArithLogicCpuTypeFunc<T>::NotEqual},
{prim::kPrimEqual->name(), &ArithLogicCpuTypeFunc<T>::Equal}};
if (arithmetic_logic_func_map.find(kernel_name_) == arithmetic_logic_func_map.end()) {
MS_LOG(EXCEPTION) << "For 'ArithmeticLogic', only supports operators in "
<< Unorderedmap2Str(arithmetic_logic_func_map) << ", but got " << kernel_name_;
}
compute_func_ = arithmetic_logic_func_map.at(kernel_name_);
}
private:
template <typename Op>
void BinaryOp(const T *input1, const T *input2, bool *out, Op op);
void Less(const T *input1, const T *input2, bool *out);
void Equal(const T *input1, const T *input2, bool *out);
void NotEqual(const T *input1, const T *input2, bool *out);
void Greater(const T *input1, const T *input2, bool *out);
void GreaterEqual(const T *input1, const T *input2, bool *out);
void LessEqual(const T *input1, const T *input2, bool *out);
void LogicalAnd(const T *input1, const T *input2, bool *out);
void LogicalOr(const T *input1, const T *input2, bool *out);
void LogicalXor(const T *input1, const T *input2, bool *out);
using TypeComputeFunc = std::function<void(ArithLogicCpuTypeFunc *, const T *, const T *, bool *)>;
TypeComputeFunc compute_func_{nullptr};
std::string kernel_name_;
size_t output_size_{1};
TypeId dtype_{kTypeUnknown};
std::vector<size_t> input_shape1_;
std::vector<size_t> input_shape2_;
std::vector<size_t> input_element_num1_;
std::vector<size_t> input_element_num2_;
std::vector<size_t> output_shape_;
std::vector<size_t> output_element_num_;
};
template <typename T>
template <typename Op>
void ArithmeticLogicCpuKernelMod<T>::BinaryOp(const T *input1, const T *input2, bool *out, Op op) {
void ArithLogicCpuTypeFunc<T>::BinaryOp(const T *input1, const T *input2, bool *out, Op op) {
size_t input1_size = 1;
size_t input2_size = 1;
@ -80,32 +174,32 @@ void ArithmeticLogicCpuKernelMod<T>::BinaryOp(const T *input1, const T *input2,
}
template <typename T>
void ArithmeticLogicCpuKernelMod<T>::Less(const T *input1, const T *input2, bool *out) {
void ArithLogicCpuTypeFunc<T>::Less(const T *input1, const T *input2, bool *out) {
BinaryOp(input1, input2, out, std::less<T>());
}
template <typename T>
void ArithmeticLogicCpuKernelMod<T>::Equal(const T *input1, const T *input2, bool *out) {
void ArithLogicCpuTypeFunc<T>::Equal(const T *input1, const T *input2, bool *out) {
BinaryOp(input1, input2, out, std::equal_to<T>());
}
template <typename T>
void ArithmeticLogicCpuKernelMod<T>::NotEqual(const T *input1, const T *input2, bool *out) {
void ArithLogicCpuTypeFunc<T>::NotEqual(const T *input1, const T *input2, bool *out) {
BinaryOp(input1, input2, out, std::not_equal_to<T>());
}
template <typename T>
void ArithmeticLogicCpuKernelMod<T>::LogicalAnd(const T *input1, const T *input2, bool *out) {
void ArithLogicCpuTypeFunc<T>::LogicalAnd(const T *input1, const T *input2, bool *out) {
BinaryOp(input1, input2, out, std::logical_and<T>());
}
template <typename T>
void ArithmeticLogicCpuKernelMod<T>::LogicalOr(const T *input1, const T *input2, bool *out) {
void ArithLogicCpuTypeFunc<T>::LogicalOr(const T *input1, const T *input2, bool *out) {
BinaryOp(input1, input2, out, std::logical_or<T>());
}
template <typename T>
void ArithmeticLogicCpuKernelMod<T>::LogicalXor(const T *input1, const T *input2, bool *out) {
void ArithLogicCpuTypeFunc<T>::LogicalXor(const T *input1, const T *input2, bool *out) {
BroadcastIterator base_iter(input_shape1_, input_shape2_, output_shape_);
auto task = [input1, input2, out, &base_iter](size_t start, size_t end) {
auto iter = base_iter;
@ -119,87 +213,172 @@ void ArithmeticLogicCpuKernelMod<T>::LogicalXor(const T *input1, const T *input2
}
template <typename T>
void ArithmeticLogicCpuKernelMod<T>::Greater(const T *input1, const T *input2, bool *out) {
void ArithLogicCpuTypeFunc<T>::Greater(const T *input1, const T *input2, bool *out) {
BinaryOp(input1, input2, out, std::greater<T>());
}
template <typename T>
void ArithmeticLogicCpuKernelMod<T>::GreaterEqual(const T *input1, const T *input2, bool *out) {
void ArithLogicCpuTypeFunc<T>::GreaterEqual(const T *input1, const T *input2, bool *out) {
BinaryOp(input1, input2, out, std::greater_equal<T>());
}
template <typename T>
void ArithmeticLogicCpuKernelMod<T>::LessEqual(const T *input1, const T *input2, bool *out) {
void ArithLogicCpuTypeFunc<T>::LessEqual(const T *input1, const T *input2, bool *out) {
BinaryOp(input1, input2, out, std::less_equal<T>());
}
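Each comparison above is a one-liner because BinaryOp is parameterized by a standard functor (std::less, std::equal_to, and so on). Here is a same-shape sketch of that functor-driven loop, leaving out the BroadcastIterator the real BinaryOp uses to align differently shaped inputs; the free function below is a simplified stand-in, not the member function above.

#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>

// Functor-parameterized elementwise comparison: out[i] = op(a[i], b[i]).
template <typename T, typename Op>
void BinaryOp(const std::vector<T> &a, const std::vector<T> &b, std::vector<bool> *out, Op op) {
  for (std::size_t i = 0; i < a.size(); ++i) {
    (*out)[i] = op(a[i], b[i]);
  }
}

int main() {
  std::vector<int> a{1, 5, 3}, b{2, 5, 1};
  std::vector<bool> out(3);
  BinaryOp(a, b, &out, std::less<int>());      // the Less kernel's functor
  std::cout << out[0] << out[1] << out[2] << '\n';  // 100
  BinaryOp(a, b, &out, std::equal_to<int>());  // the Equal kernel's functor
  std::cout << out[0] << out[1] << out[2] << '\n';  // 010
}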
template <typename T>
void ArithmeticLogicCpuKernelMod<T>::InitComputeFunc() {
static const std::unordered_map<std::string, TypeComputeFunc> arithmeticLogicFuncMap{
{prim::kPrimGreater->name(), &ArithmeticLogicCpuKernelMod<T>::Greater},
{prim::kPrimGreaterEqual->name(), &ArithmeticLogicCpuKernelMod<T>::GreaterEqual},
{prim::kPrimLogicalAnd->name(), &ArithmeticLogicCpuKernelMod<T>::LogicalAnd},
{prim::kPrimLessEqual->name(), &ArithmeticLogicCpuKernelMod<T>::LessEqual},
{prim::kPrimLogicalOr->name(), &ArithmeticLogicCpuKernelMod<T>::LogicalOr},
{prim::kPrimLogicalXor->name(), &ArithmeticLogicCpuKernelMod<T>::LogicalXor},
{prim::kPrimLess->name(), &ArithmeticLogicCpuKernelMod<T>::Less},
{prim::kPrimNotEqual->name(), &ArithmeticLogicCpuKernelMod<T>::NotEqual},
{prim::kPrimEqual->name(), &ArithmeticLogicCpuKernelMod<T>::Equal}};
if (arithmeticLogicFuncMap.find(kernel_name_) == arithmeticLogicFuncMap.end()) {
MS_LOG(EXCEPTION) << "For 'ArithmeticLogic', only supports operators in "
<< Unorderedmap2Str(arithmeticLogicFuncMap) << ", but got " << kernel_name_;
std::shared_ptr<CpuKernelFunc> SpecializeArithLogFunc() {
return std::make_shared<ArithLogicCpuTypeFunc<T>>();
}
using ArithLogicCpuFuncCreator = std::function<std::shared_ptr<CpuKernelFunc>()>;
static std::map<std::string, std::vector<std::pair<KernelAttr, ArithLogicCpuFuncCreator>>> kernel_attr_lists = {
{prim::kPrimLess->name(),
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<double>}}},
{kEqual,
{{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<bool>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<float16>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<double>}}},
{kNotEqual,
{{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<bool>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<float16>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<double>}}},
{prim::kPrimGreater->name(),
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<double>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int64_t>}}},
{prim::kPrimGreaterEqual->name(),
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<double>}}},
{prim::kPrimLessEqual->name(),
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<double>}}},
{prim::kPrimLogicalAnd->name(),
{{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<bool>}}},
{prim::kPrimLogicalOr->name(),
{{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<bool>}}},
{prim::kPrimLogicalXor->name(),
{{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
SpecializeArithLogFunc<bool>}}}};
} // namespace
void ArithmeticLogicCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
if (kernel_name != kernel_type_) {
MS_LOG(EXCEPTION) << "Need to be " << kernel_type_ << " but got kernel name as " << kernel_name;
}
compute_func_ = arithmeticLogicFuncMap.at(kernel_name_);
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "Arithmetic logic does not support this kernel data type: " << kernel_attr;
}
func_obj_ = kernel_attr_lists[kernel_name][index].second();
func_obj_->InitFunc(kernel_node);
}
template <typename T>
void ArithmeticLogicCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
input_shape1_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
input_shape2_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
if (output_shape_.empty()) {
(void)output_shape_.insert(output_shape_.begin(), 1);
std::vector<KernelAttr> ArithmeticLogicCpuKernelMod::GetOpSupport() {
auto iter = kernel_attr_lists.find(kernel_type_);
if (iter == kernel_attr_lists.end()) {
MS_LOG(EXCEPTION) << "Arithmetic logic cpu does not support " << kernel_type_;
}
output_size_ = 1;
for (size_t i = 0; i < output_shape_.size(); ++i) {
output_size_ *= output_shape_[i];
}
std::vector<KernelAttr> support_list;
std::transform(iter->second.begin(), iter->second.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, ArithLogicCpuFuncCreator> &pair) { return pair.first; });
size_t l = input_shape1_.size();
for (size_t i = 0; i < output_shape_.size() - l; ++i) {
(void)input_shape1_.insert(input_shape1_.begin(), 1);
}
l = input_shape2_.size();
for (size_t i = 0; i < output_shape_.size() - l; ++i) {
(void)input_shape2_.insert(input_shape2_.begin(), 1);
}
CPUKernelUtils::GetElementNumEveryDim(input_shape1_, &input_element_num1_);
CPUKernelUtils::GetElementNumEveryDim(input_shape2_, &input_element_num2_);
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
auto dtype_1 = AnfAlgo::GetInputDeviceDataType(kernel_node, 1);
if (dtype_ != dtype_1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the 'input1' and 'input2' should have the same data type, but got type of 'input1': "
<< dtype_ << ", and the type of 'input2': " << dtype_1;
}
InitComputeFunc();
return support_list;
}
template <typename T>
bool ArithmeticLogicCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> & /* workspace */,
const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
const auto *input1 = reinterpret_cast<T *>(inputs[0]->addr);
const auto *input2 = reinterpret_cast<T *>(inputs[1]->addr);
bool *output = reinterpret_cast<bool *>(outputs[0]->addr);
compute_func_(this, input1, input2, output);
return true;
}
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Less, []() {
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimLess->name());
});
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Equal,
[]() { return std::make_shared<ArithmeticLogicCpuKernelMod>(kEqual); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, NotEqual,
[]() { return std::make_shared<ArithmeticLogicCpuKernelMod>(kNotEqual); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Greater, []() {
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimGreater->name());
});
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, GreaterEqual, []() {
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimGreaterEqual->name());
});
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LessEqual, []() {
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimLessEqual->name());
});
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LogicalAnd, []() {
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimLogicalAnd->name());
});
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LogicalOr, []() {
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimLogicalOr->name());
});
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LogicalXor, []() {
return std::make_shared<ArithmeticLogicCpuKernelMod>(prim::kPrimLogicalXor->name());
});
} // namespace kernel
} // namespace mindspore
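GetOpSupport in both kernels projects the (KernelAttr, creator) pairs down to just the attrs using std::transform with a back_inserter. A small stand-alone sketch of that projection, with simplified stand-in types (the string attrs and Creator alias are placeholders):

#include <algorithm>
#include <iostream>
#include <iterator>
#include <string>
#include <utility>
#include <vector>

int main() {
  using Creator = int;  // placeholder for the creator functor
  // Simplified stand-in for the (KernelAttr, creator) table of one op.
  std::vector<std::pair<std::string, Creator>> table = {
      {"int32,int32->bool", 0}, {"float32,float32->bool", 1}};
  std::vector<std::string> support_list;
  std::transform(table.begin(), table.end(), std::back_inserter(support_list),
                 [](const std::pair<std::string, Creator> &p) { return p.first; });
  for (const auto &attr : support_list) {
    std::cout << attr << '\n';
  }
}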


@ -18,187 +18,39 @@
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ARITHMETIC_LOGIC_CPU_KERNEL_H_
#include <memory>
#include <string>
#include <vector>
#include <limits>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
class ArithmeticLogicCpuKernelMod : public NativeCpuKernelMod {
public:
ArithmeticLogicCpuKernelMod() = default;
explicit ArithmeticLogicCpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
~ArithmeticLogicCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
const std::vector<AddressPtr> &outputs) override {
const size_t kInputsNum = 2;
const size_t kOutputsNum = 1;
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
return func_obj_->RunFunc(inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
void InitComputeFunc();
template <typename Op>
void BinaryOp(const T *input1, const T *input2, bool *out, Op op);
void Less(const T *input1, const T *input2, bool *out);
void Equal(const T *input1, const T *input2, bool *out);
void NotEqual(const T *input1, const T *input2, bool *out);
void Greater(const T *input1, const T *input2, bool *out);
void GreaterEqual(const T *input1, const T *input2, bool *out);
void LessEqual(const T *input1, const T *input2, bool *out);
void LogicalAnd(const T *input1, const T *input2, bool *out);
void LogicalOr(const T *input1, const T *input2, bool *out);
void LogicalXor(const T *input1, const T *input2, bool *out);
using TypeComputeFunc = std::function<void(ArithmeticLogicCpuKernelMod *, const T *, const T *, bool *)>;
TypeComputeFunc compute_func_{nullptr};
size_t output_size_{1};
TypeId dtype_{kTypeUnknown};
std::vector<size_t> input_shape1_;
std::vector<size_t> input_shape2_;
std::vector<size_t> input_element_num1_;
std::vector<size_t> input_element_num2_;
std::vector<size_t> output_shape_;
std::vector<size_t> output_element_num_;
std::shared_ptr<CpuKernelFunc> func_obj_;
std::string kernel_type_{"Unknown"};
};
MS_REG_CPU_KERNEL_T(
Less, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, int);
MS_REG_CPU_KERNEL_T(
Less, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(
Less, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, int64_t);
MS_REG_CPU_KERNEL_T(
Less, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, double);
MS_REG_CPU_KERNEL_T(
Equal, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, bool);
MS_REG_CPU_KERNEL_T(
Equal, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, int8_t);
MS_REG_CPU_KERNEL_T(
Equal, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, int16_t);
MS_REG_CPU_KERNEL_T(
Equal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, int);
MS_REG_CPU_KERNEL_T(
Equal, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, int64_t);
MS_REG_CPU_KERNEL_T(
Equal, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, uint8_t);
MS_REG_CPU_KERNEL_T(
Equal, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, uint16_t);
MS_REG_CPU_KERNEL_T(
Equal, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, uint32_t);
MS_REG_CPU_KERNEL_T(
Equal, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, float16);
MS_REG_CPU_KERNEL_T(
Equal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(
Equal, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, double);
MS_REG_CPU_KERNEL_T(
NotEqual, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, bool);
MS_REG_CPU_KERNEL_T(
NotEqual, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, int8_t);
MS_REG_CPU_KERNEL_T(
NotEqual, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, int16_t);
MS_REG_CPU_KERNEL_T(
NotEqual, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, int);
MS_REG_CPU_KERNEL_T(
NotEqual, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, int64_t);
MS_REG_CPU_KERNEL_T(
NotEqual, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, uint8_t);
MS_REG_CPU_KERNEL_T(
NotEqual, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, uint16_t);
MS_REG_CPU_KERNEL_T(
NotEqual, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, uint32_t);
MS_REG_CPU_KERNEL_T(
NotEqual,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, float16);
MS_REG_CPU_KERNEL_T(
NotEqual,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(
NotEqual,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, double);
MS_REG_CPU_KERNEL_T(
Greater, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, int);
MS_REG_CPU_KERNEL_T(
Greater,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(
Greater,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, double);
MS_REG_CPU_KERNEL_T(
Greater, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, int64_t);
MS_REG_CPU_KERNEL_T(
GreaterEqual,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, int);
MS_REG_CPU_KERNEL_T(
GreaterEqual,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(
GreaterEqual,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, int64_t);
MS_REG_CPU_KERNEL_T(
GreaterEqual,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, double);
MS_REG_CPU_KERNEL_T(
LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, int);
MS_REG_CPU_KERNEL_T(
LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, int64_t);
MS_REG_CPU_KERNEL_T(
LessEqual,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(
LessEqual,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, double);
MS_REG_CPU_KERNEL_T(
LogicalAnd, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, bool);
MS_REG_CPU_KERNEL_T(
LogicalOr, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, bool);
MS_REG_CPU_KERNEL_T(
LogicalXor, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
ArithmeticLogicCpuKernelMod, bool);
} // namespace kernel
} // namespace mindspore


@ -19,22 +19,75 @@
#include <algorithm>
#include <cmath>
#include <complex>
#include <map>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
#include "plugin/device/cpu/kernel/mkldnn/eltwise_cpu_kernel.h"
namespace mindspore {
namespace kernel {
namespace {
using complex64 = std::complex<float>;
using complex128 = std::complex<double>;
constexpr float kMaxNegSerialSize = 5000.0f;
constexpr float kMaxSquareSerialSize = 5000.0f;
constexpr size_t kInputsNum = 1;
constexpr size_t kOutputsNum = 1;
constexpr auto kSquare = "Square";
constexpr auto kNeg = "Neg";
constexpr auto kSign = "Sign";
constexpr auto kFloor = "Floor";
constexpr auto kRint = "Rint";
constexpr auto kRound = "Round";
constexpr auto kReciprocal = "Reciprocal";
constexpr auto kGeLU = "GeLU";
constexpr auto kLogicalNot = "LogicalNot";
constexpr auto kAsin = "Asin";
constexpr auto kACos = "ACos";
constexpr auto kAtan = "Atan";
constexpr auto kSin = "Sin";
constexpr auto kCos = "Cos";
constexpr auto kTan = "Tan";
constexpr auto kSinh = "Sinh";
constexpr auto kCosh = "Cosh";
constexpr auto kAsinh = "Asinh";
constexpr auto kAcosh = "Acosh";
constexpr auto kAtanh = "Atanh";
constexpr auto kAbs = "Abs";
constexpr auto kSqrt = "Sqrt";
constexpr auto kRsqrt = "Rsqrt";
class ArithmeticSelfCpuKernelFunc : public CpuKernelFunc {
public:
ArithmeticSelfCpuKernelFunc() = default;
~ArithmeticSelfCpuKernelFunc() override = default;
void InitFunc(const CNodePtr &kernel_node) override;
bool RunFunc(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
void LaunchLogicalNot(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
template <typename T>
void LaunchKernelComplex(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
TypeId dtype_{kTypeUnknown};
std::string kernel_name_;
};
template <typename T>
void Square(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void Square(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = in[i] * in[i];
@ -44,7 +97,7 @@ void Square(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t siz
}
template <typename T>
void Sign(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void Sign(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
if (in[i] < 0) {
@ -60,7 +113,7 @@ void Sign(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
}
template <typename T>
void Neg(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void Neg(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = -in[i];
@ -69,7 +122,7 @@ void Neg(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
ParallelLaunch(task, size, kMaxNegSerialSize);
}
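Every unary functor in this file follows the same pattern: capture the input/output pointers in a range task, then hand it to ParallelLaunch together with a serial-size threshold (kMaxNegSerialSize, kMaxSquareSerialSize). A minimal sketch of that contract — purely illustrative, since the real ParallelLaunch lives in the CPU kernel runtime and reuses a thread pool rather than spawning threads per call:

#include <algorithm>
#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

inline void ParallelLaunchSketch(const std::function<void(size_t, size_t)> &task, size_t size,
                                 float serial_threshold) {
  if (static_cast<float>(size) <= serial_threshold) {
    task(0, size);  // small workloads stay on the calling thread
    return;
  }
  size_t thread_num = std::max<size_t>(1, std::thread::hardware_concurrency());
  size_t chunk = (size + thread_num - 1) / thread_num;
  std::vector<std::thread> workers;
  for (size_t start = 0; start < size; start += chunk) {
    workers.emplace_back(task, start, std::min(size, start + chunk));  // each worker gets [start, end)
  }
  for (auto &worker : workers) {
    worker.join();
  }
}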
void LogicalNot(ArithmeticSelfCpuKernelMod *content, const bool *in, bool *out, size_t size) {
void LogicalNot(ArithmeticSelfCpuKernelFunc *content, const bool *in, bool *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = !in[i];
@ -79,7 +132,7 @@ void LogicalNot(ArithmeticSelfCpuKernelMod *content, const bool *in, bool *out,
}
template <typename T>
void Floor(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void Floor(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(floor(in[i]));
@ -89,7 +142,7 @@ void Floor(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size
}
template <typename T>
void Rint(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void Rint(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(rint(in[i]));
@ -99,7 +152,7 @@ void Rint(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
}
template <typename T>
void Round(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void Round(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(nearbyint(in[i]));
@ -109,7 +162,7 @@ void Round(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size
}
template <typename T>
void Reciprocal(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void Reciprocal(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(1.0 / in[i]);
@ -119,7 +172,7 @@ void Reciprocal(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t
}
template <typename T>
void Gelu(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void Gelu(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
auto factor_a = static_cast<T>(0.7978845608);
auto factor_b = static_cast<T>(0.044715);
@ -134,7 +187,7 @@ void Gelu(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
}
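The constants in Gelu above are recognizably the tanh approximation of GeLU — factor_a is sqrt(2/pi) — so, assuming the elided loop body applies the standard form, the kernel computes:

\mathrm{GeLU}(x) \approx \frac{x}{2}\Bigl(1 + \tanh\bigl(\sqrt{2/\pi}\,(x + 0.044715\,x^{3})\bigr)\Bigr), \qquad \sqrt{2/\pi} \approx 0.7978845608.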
template <typename T>
void Asin(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void Asin(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(asin(static_cast<double>(in[i])));
@ -144,7 +197,7 @@ void Asin(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
}
template <typename T>
void ACos(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void ACos(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(acos(static_cast<double>(in[i])));
@ -154,7 +207,7 @@ void ACos(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
}
template <typename T>
void Atan(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void Atan(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(atan(static_cast<double>(in[i])));
@ -164,7 +217,7 @@ void Atan(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
}
template <typename T>
void Sin(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void Sin(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(sin(static_cast<double>(in[i])));
@ -174,7 +227,7 @@ void Sin(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
}
template <typename T>
void Cos(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void Cos(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(cos(static_cast<double>(in[i])));
@ -184,7 +237,7 @@ void Cos(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
}
template <typename T>
void ComplexSin(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void ComplexSin(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(sin(in[i]));
@ -194,7 +247,7 @@ void ComplexSin(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t
}
template <typename T>
void ComplexSinh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void ComplexSinh(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(sinh(in[i]));
@ -204,7 +257,7 @@ void ComplexSinh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_
}
template <typename T>
void ComplexCos(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void ComplexCos(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(cos(in[i]));
@ -214,7 +267,7 @@ void ComplexCos(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t
}
template <typename T>
void ComplexCosh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void ComplexCosh(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(cosh(in[i]));
@ -224,7 +277,7 @@ void ComplexCosh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_
}
template <typename T>
void Tan(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void Tan(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(tan(static_cast<double>(in[i])));
@ -234,7 +287,7 @@ void Tan(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
}
template <typename T>
void Sinh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void Sinh(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(sinh(static_cast<double>(in[i])));
@ -244,7 +297,7 @@ void Sinh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
}
template <typename T>
void Cosh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void Cosh(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(cosh(static_cast<double>(in[i])));
@ -254,7 +307,7 @@ void Cosh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
}
template <typename T>
void ComplexAsinh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void ComplexAsinh(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(asinh(in[i]));
@ -264,7 +317,7 @@ void ComplexAsinh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size
}
template <typename T>
void Asinh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void Asinh(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(asinh(static_cast<double>(in[i])));
@ -274,7 +327,7 @@ void Asinh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size
}
template <typename T>
void ComplexAcosh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void ComplexAcosh(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(acosh(in[i]));
@ -284,7 +337,7 @@ void ComplexAcosh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size
}
template <typename T>
void Acosh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void Acosh(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(acosh(static_cast<double>(in[i])));
@ -294,7 +347,7 @@ void Acosh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size
}
template <typename T>
void Atanh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void Atanh(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(atanh(static_cast<double>(in[i])));
@ -304,7 +357,7 @@ void Atanh(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size
}
template <typename T>
void Abs(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void Abs(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = abs(in[i]);
@ -314,7 +367,7 @@ void Abs(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
}
template <typename T>
void Sqrt(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void Sqrt(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(sqrt(in[i]));
@ -324,7 +377,7 @@ void Sqrt(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size)
}
template <typename T>
void Rsqrt(ArithmeticSelfCpuKernelMod *content, const T *in, T *out, size_t size) {
void Rsqrt(ArithmeticSelfCpuKernelFunc *content, const T *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(1) / sqrt(in[i]);
@ -337,17 +390,41 @@ template <typename T>
void Identity(const T *in, T *out, size_t size) {
(void)std::copy(in, in + size, out);
}
} // namespace
void ArithmeticSelfCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
template <typename T>
bool IdentityCpuFunc(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs) {
T *input = reinterpret_cast<T *>(inputs[0]->addr);
T *output = reinterpret_cast<T *>(outputs[0]->addr);
size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
Identity<T>(input, output, lens);
return true;
}
static std::vector<std::pair<KernelAttr, LaunchFunc>> identity_kernel_attr_lists = {
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), IdentityCpuFunc<uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), IdentityCpuFunc<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), IdentityCpuFunc<uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), IdentityCpuFunc<int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), IdentityCpuFunc<uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), IdentityCpuFunc<int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), IdentityCpuFunc<uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), IdentityCpuFunc<int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), IdentityCpuFunc<complex64>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), IdentityCpuFunc<complex128>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), IdentityCpuFunc<double>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), IdentityCpuFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), IdentityCpuFunc<float16>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), IdentityCpuFunc<bool>}};
void ArithmeticSelfCpuKernelFunc::InitFunc(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
}
bool ArithmeticSelfCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
bool ArithmeticSelfCpuKernelFunc::RunFunc(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat16) {
@ -373,8 +450,8 @@ bool ArithmeticSelfCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &i
return true;
}
void ArithmeticSelfCpuKernelMod::LaunchLogicalNot(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
void ArithmeticSelfCpuKernelFunc::LaunchLogicalNot(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
auto *input = reinterpret_cast<bool *>(inputs[0]->addr);
auto *output = reinterpret_cast<bool *>(outputs[0]->addr);
size_t lens = outputs[0]->size / sizeof(bool);
@ -382,13 +459,13 @@ void ArithmeticSelfCpuKernelMod::LaunchLogicalNot(const std::vector<AddressPtr>
}
template <typename T>
void ArithmeticSelfCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
void ArithmeticSelfCpuKernelFunc::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
const auto *input = reinterpret_cast<T *>(inputs[0]->addr);
auto *output = reinterpret_cast<T *>(outputs[0]->addr);
const size_t lens = outputs[0]->size / sizeof(T);
static const std::unordered_map<std::string,
std::function<void(ArithmeticSelfCpuKernelMod *, const T *, T *, size_t)>>
std::function<void(ArithmeticSelfCpuKernelFunc *, const T *, T *, size_t)>>
arithmeticSelfFuncMap{{prim::kPrimSquare->name(), Square<T>},
{prim::kPrimSign->name(), Sign<T>},
{prim::kPrimNeg->name(), Neg<T>},
@ -421,13 +498,13 @@ void ArithmeticSelfCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inp
}
template <typename T>
void ArithmeticSelfCpuKernelMod::LaunchKernelComplex(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
void ArithmeticSelfCpuKernelFunc::LaunchKernelComplex(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
const auto *input = reinterpret_cast<T *>(inputs[0]->addr);
auto *output = reinterpret_cast<T *>(outputs[0]->addr);
const size_t lens = outputs[0]->size / sizeof(T);
static const std::unordered_map<std::string,
std::function<void(ArithmeticSelfCpuKernelMod *, const T *, T *, size_t)>>
std::function<void(ArithmeticSelfCpuKernelFunc *, const T *, T *, size_t)>>
arithmeticSelfFuncMap{{prim::kPrimSquare->name(), Square<T>}, {prim::kPrimAcosh->name(), ComplexAcosh<T>},
{prim::kPrimAsinh->name(), ComplexAsinh<T>}, {prim::kPrimNeg->name(), Neg<T>},
{prim::kPrimSinh->name(), ComplexSinh<T>}, {prim::kPrimCosh->name(), ComplexCosh<T>},
@ -435,22 +512,223 @@ void ArithmeticSelfCpuKernelMod::LaunchKernelComplex(const std::vector<AddressPt
{prim::kPrimRsqrt->name(), Rsqrt<T>}};
const auto func_pair = arithmeticSelfFuncMap.find(kernel_name_);
if (arithmeticSelfFuncMap.find(kernel_name_) == arithmeticSelfFuncMap.end()) {
MS_LOG(EXCEPTION) << "ArithmeticSelfCpuKernelMod does not support " << kernel_name_;
MS_LOG(EXCEPTION) << "ArithmeticSelfCpuKernelFunc does not support " << kernel_name_;
}
func_pair->second(this, input, output, lens);
}
template <typename T>
bool IdentityCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kOutputsNum, kernel_name_);
T *input = reinterpret_cast<T *>(inputs[0]->addr);
T *output = reinterpret_cast<T *>(outputs[0]->addr);
size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
Identity<T>(input, output, lens);
return true;
// MKLDNN Sqrt
class SqrtMKLKernelFunc : public CpuKernelFunc, private EltWiseCpuKernelMod {
public:
SqrtMKLKernelFunc() : EltWiseCpuKernelMod(kSqrt) {}
~SqrtMKLKernelFunc() override = default;
void InitFunc(const CNodePtr &kernel_node) override {
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
if (kernel_name != kSqrt) {
MS_LOG(EXCEPTION) << "Should be " << kSqrt << ", but got " << kernel_name;
}
EltWiseCpuKernelMod::InitKernel(kernel_node);
}
bool RunFunc(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
return EltWiseCpuKernelMod::Launch(inputs, workspace, outputs);
}
};
std::shared_ptr<CpuKernelFunc> CreateArithSelfFunc() { return std::make_shared<ArithmeticSelfCpuKernelFunc>(); }
using ArithFuncCreator = std::function<std::shared_ptr<CpuKernelFunc>()>;
static std::map<std::string, std::vector<std::pair<KernelAttr, ArithFuncCreator>>> arith_kernel_attr_list_map = {
{kRsqrt,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), CreateArithSelfFunc}}},
{kSquare,
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), CreateArithSelfFunc}}},
{kNeg,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), CreateArithSelfFunc}}},
{kSign,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
{kFloor,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
{kRint,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
{kRound,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
{kReciprocal,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
{kGeLU, {{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc}}},
{kLogicalNot, {{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), CreateArithSelfFunc}}},
{kAsin,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
{kACos,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
{kAtan,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
{kSin,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), CreateArithSelfFunc}}},
{kCos,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), CreateArithSelfFunc}}},
{kTan,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
{kSinh,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), CreateArithSelfFunc}}},
{kCosh,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), CreateArithSelfFunc}}},
{kAsinh,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), CreateArithSelfFunc}}},
{kAcosh,
{{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
{kAtanh,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
{kAbs,
{{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc}}},
{kSqrt,
{{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateArithSelfFunc},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
[]() { return std::make_shared<SqrtMKLKernelFunc>(); }}}}};
} // namespace
void ArithmeticSelfCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
if (kernel_name_ != kernel_type_) {
MS_LOG(EXCEPTION) << "Need to be " << kernel_type_ << " but got kernel name as " << kernel_name_;
}
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "ArithmeticSelf cpu does not support this kernel data type: " << kernel_attr;
}
func_obj_ = arith_kernel_attr_list_map[kernel_name_][index].second();
func_obj_->InitFunc(kernel_node);
}
std::vector<KernelAttr> ArithmeticSelfCpuKernelMod::GetOpSupport() {
auto iter = arith_kernel_attr_list_map.find(kernel_type_);
if (iter == arith_kernel_attr_list_map.end()) {
MS_LOG(EXCEPTION) << "Arithmetic self cpu does not support " << kernel_type_;
}
std::vector<KernelAttr> support_list;
std::transform(iter->second.begin(), iter->second.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, ArithFuncCreator> &pair) { return pair.first; });
return support_list;
}
void IdentityCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "Identity cpu does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = identity_kernel_attr_lists[index].second;
}
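Both InitKernel bodies above rely on MatchKernelAttr returning a (matched, index) pair, where the index selects a row of the op's support table — a creator for ArithmeticSelf, a launch function for Identity. A minimal sketch of that behavior, assuming exact attr equality (the real matcher also understands formats and the all-same flag):

#include <cstddef>
#include <utility>
#include <vector>

template <typename Attr>
std::pair<bool, size_t> MatchKernelAttrSketch(const Attr &node_attr, const std::vector<Attr> &support_list) {
  for (size_t i = 0; i < support_list.size(); ++i) {
    if (support_list[i] == node_attr) {  // assumes Attr is equality-comparable
      return {true, i};                  // the index later picks the creator/launch row
    }
  }
  return {false, 0};
}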
std::vector<KernelAttr> IdentityCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> kernel_attr_list;
std::transform(identity_kernel_attr_lists.begin(), identity_kernel_attr_lists.end(),
std::back_inserter(kernel_attr_list),
[](const std::pair<KernelAttr, LaunchFunc> &pair) { return pair.first; });
return kernel_attr_list;
}
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Rsqrt,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kRsqrt); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Square,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kSquare); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Neg,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kNeg); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Sign,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kSign); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Floor,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kFloor); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Rint,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kRint); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Round,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kRound); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Reciprocal,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kReciprocal); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, GeLU,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kGeLU); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, LogicalNot,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kLogicalNot); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Asin,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kAsin); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, ACos,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kACos); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Atan,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kAtan); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Sin,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kSin); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Cos,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kCos); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Tan,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kTan); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Sinh,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kSinh); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Cosh,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kCosh); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Asinh,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kAsinh); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Acosh,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kAcosh); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Atanh,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kAtanh); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Abs,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kAbs); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, Sqrt,
[]() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kSqrt); });
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Identity, IdentityCpuKernelMod);
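These registrations are the point of the refactor: instead of one MS_REG_CPU_KERNEL(_T) line per (op, dtype) combination, a single creator per op is registered with a name-keyed factory, and dtype selection happens later through GetOpSupport/MatchKernelAttr. A minimal sketch of the factory shape these macros plausibly feed — illustrative only, the real one lives in plugin/factory/ms_factory.h:

#include <functional>
#include <map>
#include <memory>
#include <string>

template <typename Base>
class FactorySketch {
 public:
  using Creator = std::function<std::shared_ptr<Base>()>;
  static FactorySketch &Instance() {
    static FactorySketch instance;  // one registry per base class
    return instance;
  }
  void Register(const std::string &name, Creator &&creator) { creators_[name] = std::move(creator); }
  std::shared_ptr<Base> Create(const std::string &name) const {
    auto iter = creators_.find(name);
    return iter == creators_.end() ? nullptr : iter->second();
  }

 private:
  std::map<std::string, Creator> creators_;
};

// Usage mirroring the macros above (names other than the op are hypothetical):
//   FactorySketch<NativeCpuKernelMod>::Instance().Register(
//       "Rsqrt", []() { return std::make_shared<ArithmeticSelfCpuKernelMod>(kRsqrt); });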
} // namespace kernel
} // namespace mindspore

View File

@ -20,211 +20,56 @@
#include <complex>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
using complex64 = std::complex<float>;
using complex128 = std::complex<double>;
constexpr auto kUnknown = "Unknown";
class ArithmeticSelfCpuKernelMod : public NativeCpuKernelMod {
public:
ArithmeticSelfCpuKernelMod() = default;
explicit ArithmeticSelfCpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
~ArithmeticSelfCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
const std::vector<AddressPtr> &outputs) override {
return func_obj_->RunFunc(inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
void LaunchLogicalNot(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
template <typename T>
void LaunchKernelComplex(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
TypeId dtype_{kTypeUnknown};
std::shared_ptr<CpuKernelFunc> func_obj_;
std::string kernel_type_{kUnknown};
};
template <typename T>
class IdentityCpuKernelMod : public ArithmeticSelfCpuKernelMod {
using LaunchFunc = std::function<bool(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &)>;
class IdentityCpuKernelMod : public NativeCpuKernelMod {
public:
IdentityCpuKernelMod() = default;
~IdentityCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
};
MS_REG_CPU_KERNEL(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Square, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Neg, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Neg, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Neg, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Neg, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Sign, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Sign, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Sign, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Floor, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Rint, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Round, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Round, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Reciprocal, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(GeLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(LogicalNot, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Asin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Asin, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(ACos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(ACos, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Atan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Atan, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Sin, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Cos, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Sin, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Cos, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Tan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Tan, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Sinh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Sinh, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Cosh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Cosh, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Sinh, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Cosh, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Sinh, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Cosh, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Asinh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Asinh, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Acosh, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Acosh, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Asinh, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Asinh, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Acosh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Acosh, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Atanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Atanh, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Abs, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Abs, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticSelfCpuKernelMod);
MS_REG_CPU_KERNEL(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ArithmeticSelfCpuKernelMod);
const std::vector<AddressPtr> &outputs) override {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), 1, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), 1, kernel_name_);
return kernel_func_(inputs, outputs);
}
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
IdentityCpuKernelMod, uint64_t);
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
IdentityCpuKernelMod, int64_t);
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
IdentityCpuKernelMod, uint32_t);
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
IdentityCpuKernelMod, int32_t);
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
IdentityCpuKernelMod, uint16_t);
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
IdentityCpuKernelMod, int16_t);
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
IdentityCpuKernelMod, uint8_t);
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
IdentityCpuKernelMod, int8_t);
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
IdentityCpuKernelMod, complex64);
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
IdentityCpuKernelMod, complex128);
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
IdentityCpuKernelMod, double);
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
IdentityCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
IdentityCpuKernelMod, float16);
MS_REG_CPU_KERNEL_T(Identity, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
IdentityCpuKernelMod, bool);
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
LaunchFunc kernel_func_;
};
} // namespace kernel
} // namespace mindspore

View File

@ -96,5 +96,25 @@ bool AssignCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std
ParallelLaunchAutoSearch(task, total_size, this, &parallel_search_info_);
return true;
}
std::vector<KernelAttr> AssignCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> kernel_attr_list = {
KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)};
return kernel_attr_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Assign, AssignCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -22,7 +22,7 @@
#include <unordered_map>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
@ -36,62 +36,14 @@ class AssignCpuKernelMod : public NativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
size_t batch_size_{1};
size_t input_x_dtype_size_{4};
TypeId input_x_dtype_{kTypeUnknown};
};
MS_REG_CPU_KERNEL(
Assign, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
AssignCpuKernelMod);
MS_REG_CPU_KERNEL(
Assign, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
AssignCpuKernelMod);
MS_REG_CPU_KERNEL(
Assign, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
AssignCpuKernelMod);
MS_REG_CPU_KERNEL(
Assign, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
AssignCpuKernelMod);
MS_REG_CPU_KERNEL(
Assign, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
AssignCpuKernelMod);
MS_REG_CPU_KERNEL(
Assign, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
AssignCpuKernelMod);
MS_REG_CPU_KERNEL(
Assign, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
AssignCpuKernelMod);
MS_REG_CPU_KERNEL(
Assign, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
AssignCpuKernelMod);
MS_REG_CPU_KERNEL(
Assign, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
AssignCpuKernelMod);
MS_REG_CPU_KERNEL(
Assign,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
AssignCpuKernelMod);
MS_REG_CPU_KERNEL(
Assign,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
AssignCpuKernelMod);
MS_REG_CPU_KERNEL(
Assign,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
AssignCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -103,5 +103,7 @@ bool BiasAddCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const st
return true;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, BiasAdd, BiasAddCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -21,7 +21,7 @@
#include <memory>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
@ -39,7 +39,6 @@ class BiasAddCpuKernelMod : public NativeCpuKernelMod {
std::vector<size_t> input_shape_;
std::vector<size_t> bias_shape_;
};
MS_REG_CPU_KERNEL(BiasAdd, KernelAttr(), BiasAddCpuKernelMod);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BIAS_ADD_CPU_KERNEL_H_

View File

@ -70,5 +70,7 @@ bool BiasAddGradCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, cons
}
return true;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, BiasAddGrad, BiasAddGradCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -21,7 +21,7 @@
#include <memory>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
@ -37,7 +37,6 @@ class BiasAddGradCpuKernelMod : public NativeCpuKernelMod {
private:
std::vector<size_t> input_shape_;
};
MS_REG_CPU_KERNEL(BiasAddGrad, KernelAttr(), BiasAddGradCpuKernelMod);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BIAS_ADD_GRAD_CPU_KERNEL_H_

View File

@ -129,5 +129,25 @@ void BinaryCrossEntropyCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
<< reduction;
}
}
std::vector<KernelAttr> BinaryCrossEntropyCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> kernel_attr_list = {
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)};
return kernel_attr_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, BinaryCrossEntropy, BinaryCrossEntropyCpuKernelMod);
} // namespace kernel
} // namespace mindspore
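For reference, assuming this kernel implements the standard (optionally weighted) binary cross-entropy that the 2- and 3-input attrs above suggest, the per-element loss is

\ell_i = -w_i\bigl(y_i \log x_i + (1 - y_i)\log(1 - x_i)\bigr),

with reduction "none" returning the vector of \ell_i, "sum" returning \sum_i \ell_i, and "mean" returning \frac{1}{N}\sum_i \ell_i; w_i defaults to 1 when the weight input is absent.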

View File

@ -21,7 +21,7 @@
#include <string>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
@ -36,6 +36,9 @@ class BinaryCrossEntropyCpuKernelMod : public NativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
void LaunchToScalar(const int &input_size, const int &reduction, T *loss, T *tmp_loss) const;
@ -47,28 +50,6 @@ class BinaryCrossEntropyCpuKernelMod : public NativeCpuKernelMod {
ReductionType reduction_{kNone};
bool weight_defined_{false}; // true: there are 3 inputs, false: there are 2 inputs(no [weight])
};
MS_REG_CPU_KERNEL(BinaryCrossEntropy,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
BinaryCrossEntropyCpuKernelMod);
MS_REG_CPU_KERNEL(BinaryCrossEntropy,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
BinaryCrossEntropyCpuKernelMod);
MS_REG_CPU_KERNEL(
BinaryCrossEntropy,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BinaryCrossEntropyCpuKernelMod);
MS_REG_CPU_KERNEL(
BinaryCrossEntropy,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BinaryCrossEntropyCpuKernelMod);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_BINARY_CROSS_ENTROPY_KERNEL_H

View File

@ -108,5 +108,24 @@ void BinaryCrossEntropyGradCpuKernelMod::InitKernel(const CNodePtr &kernel_node)
<< reduction;
}
}
std::vector<KernelAttr> BinaryCrossEntropyGradCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> kernel_attr_list = {KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)};
return kernel_attr_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, BinaryCrossEntropyGrad, BinaryCrossEntropyGradCpuKernelMod);
} // namespace kernel
} // namespace mindspore
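
One small point in both GetOpSupport() bodies above: the attr list is a function-local static, so it is built on the first kernel-selection query and reused afterwards, and since C++11 that initialization is also thread-safe. A reduced sketch of the idiom (KernelAttr and the dtype enums are the ones from this codebase):

std::vector<KernelAttr> SupportListSketch() {
  // Constructed exactly once, on first call; later calls only copy it out.
  static std::vector<KernelAttr> kernel_attr_list = {
    KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
    KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)};
  return kernel_attr_list;
}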

View File

@ -21,7 +21,7 @@
#include <string>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/cpu/kernel/binary_cross_entropy_cpu_kernel.h"
namespace mindspore {
@ -35,6 +35,9 @@ class BinaryCrossEntropyGradCpuKernelMod : public NativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) const;
@ -44,36 +47,6 @@ class BinaryCrossEntropyGradCpuKernelMod : public NativeCpuKernelMod {
ReductionType reduction_{kNone};
bool weight_defined_{false}; // true: there are 4 inputs, false: there are 3 inputs (no [weight])
};
MS_REG_CPU_KERNEL(BinaryCrossEntropyGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
BinaryCrossEntropyGradCpuKernelMod);
MS_REG_CPU_KERNEL(BinaryCrossEntropyGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
BinaryCrossEntropyGradCpuKernelMod);
MS_REG_CPU_KERNEL(BinaryCrossEntropyGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
BinaryCrossEntropyGradCpuKernelMod);
MS_REG_CPU_KERNEL(BinaryCrossEntropyGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
BinaryCrossEntropyGradCpuKernelMod);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_BINARY_CROSS_ENTROPY_GRAD_KERNEL_H

View File

@ -15,12 +15,12 @@
*/
#include "plugin/device/cpu/kernel/boundingbox_decode_cpu_kernel.h"
#include <utility>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
template <typename T>
void BoundingBoxDecodeCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
void BoundingBoxDecodeCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
@ -71,12 +71,13 @@ void BoundingBoxDecodeCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the length of 'max_shape' should be at least 2, but got: " << max_shape_.size();
}
InitTaskFunc(kernel_node);
}
template <typename T>
bool BoundingBoxDecodeCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
bool BoundingBoxDecodeCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
auto anchor_box = reinterpret_cast<T *>(inputs[0]->addr);
auto deltas = reinterpret_cast<T *>(inputs[1]->addr);
auto bboxes = reinterpret_cast<T *>(outputs[0]->addr);
@ -156,5 +157,30 @@ bool BoundingBoxDecodeCpuKernelMod<T>::Launch(const std::vector<kernel::AddressP
ParallelLaunchAutoSearch(task, elem_num, this, &parallel_search_info_);
return true;
}
std::vector<std::pair<KernelAttr, BoundingBoxDecodeCpuKernelMod::BoundingBoxDecodeFunc>>
BoundingBoxDecodeCpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&BoundingBoxDecodeCpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&BoundingBoxDecodeCpuKernelMod::LaunchKernel<float16>}};
void BoundingBoxDecodeCpuKernelMod::InitTaskFunc(const CNodePtr &kernel_node) {
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "BoundingBoxDecode does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
}
std::vector<KernelAttr> BoundingBoxDecodeCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, BoundingBoxDecodeFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, BoundingBoxDecode, BoundingBoxDecodeCpuKernelMod);
} // namespace kernel
} // namespace mindspore
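
The dispatch shape introduced here recurs through the rest of the commit: a static table of (KernelAttr, pointer-to-member) pairs, a GetOpSupport() that projects the attrs out of it, and an init step that matches the node's attr against that list and caches the chosen LaunchKernel<T> instantiation in kernel_func_. MatchKernelAttr is assumed, as its call sites imply, to return a (matched, index) pair into the support list; its definition is not part of this diff. Below is a self-contained sketch of the idiom with the MindSpore types stripped out; a string stands in for KernelAttr.

#include <cstddef>
#include <functional>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

struct DemoKernelMod {
  template <typename T>
  bool LaunchKernel(const void *in, void *out, std::size_t n) {
    const T *src = static_cast<const T *>(in);
    T *dst = static_cast<T *>(out);
    for (std::size_t i = 0; i < n; ++i) {
      dst[i] = src[i];  // stand-in for the real per-element work
    }
    return true;
  }

  using DemoFunc = std::function<bool(DemoKernelMod *, const void *, void *, std::size_t)>;
  static std::vector<std::pair<std::string, DemoFunc>> func_list_;  // attr stand-in -> impl
  DemoFunc kernel_func_;

  void Init(const std::string &dtype) {  // plays the role of InitTaskFunc + MatchKernelAttr
    for (const auto &entry : func_list_) {
      if (entry.first == dtype) {
        kernel_func_ = entry.second;  // cache the matched instantiation
        return;
      }
    }
    throw std::runtime_error("does not support this kernel data type: " + dtype);
  }
};

std::vector<std::pair<std::string, DemoKernelMod::DemoFunc>> DemoKernelMod::func_list_ = {
  {"float32", &DemoKernelMod::LaunchKernel<float>},
  {"int32", &DemoKernelMod::LaunchKernel<int>}};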

View File

@ -19,15 +19,15 @@
#include <vector>
#include <algorithm>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
constexpr size_t MIN_MAX_SHAPE_SIZE = 2;
constexpr size_t INPUT_NUMS = 2;
template <typename T>
class BoundingBoxDecodeCpuKernelMod : public NativeCpuKernelMod {
public:
BoundingBoxDecodeCpuKernelMod() = default;
@ -35,26 +35,30 @@ class BoundingBoxDecodeCpuKernelMod : public NativeCpuKernelMod {
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
void InitTaskFunc(const CNodePtr &kernel_node);
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
using BoundingBoxDecodeFunc =
std::function<bool(BoundingBoxDecodeCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, BoundingBoxDecodeFunc>> func_list_;
BoundingBoxDecodeFunc kernel_func_;
std::vector<float> means_;
std::vector<float> stds_;
std::vector<int> max_shape_;
float wh_ratio_clip_{0.016};
};
MS_REG_CPU_KERNEL_T(
BoundingBoxDecode,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BoundingBoxDecodeCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(
BoundingBoxDecode,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BoundingBoxDecodeCpuKernelMod, float16);
} // namespace kernel
} // namespace mindspore

View File

@ -15,12 +15,35 @@
*/
#include "plugin/device/cpu/kernel/boundingbox_encode_cpu_kernel.h"
#include <utility>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
template <typename T>
void BoundingBoxEncodeCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
std::vector<std::pair<KernelAttr, BoundingBoxEncodeCpuKernelMod::BoundingBoxEncodeFunc>>
BoundingBoxEncodeCpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&BoundingBoxEncodeCpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&BoundingBoxEncodeCpuKernelMod::LaunchKernel<float16>}};
void BoundingBoxEncodeCpuKernelMod::InitTaskFunc(const CNodePtr &kernel_node) {
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "BoundingBoxEncode does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
}
std::vector<KernelAttr> BoundingBoxEncodeCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, BoundingBoxEncodeFunc> &pair) { return pair.first; });
return support_list;
}
void BoundingBoxEncodeCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
@ -61,12 +84,12 @@ void BoundingBoxEncodeCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
"but got the length of 'means': "
<< means_.size() << ", and the length of 'stds': " << stds_.size();
}
InitTaskFunc(kernel_node);
}
template <typename T>
bool BoundingBoxEncodeCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
bool BoundingBoxEncodeCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
auto anchor_box = reinterpret_cast<T *>(inputs[0]->addr);
auto groundtruth_box = reinterpret_cast<T *>(inputs[1]->addr);
auto deltas = reinterpret_cast<T *>(outputs[0]->addr);
@ -128,5 +151,7 @@ bool BoundingBoxEncodeCpuKernelMod<T>::Launch(const std::vector<kernel::AddressP
return true;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, BoundingBoxEncode, BoundingBoxEncodeCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -19,14 +19,14 @@
#include <vector>
#include <algorithm>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
constexpr size_t INPUT_NUMS = 2;
template <typename T>
class BoundingBoxEncodeCpuKernelMod : public NativeCpuKernelMod {
public:
BoundingBoxEncodeCpuKernelMod() = default;
@ -34,24 +34,28 @@ class BoundingBoxEncodeCpuKernelMod : public NativeCpuKernelMod {
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
using BoundingBoxEncodeFunc =
std::function<bool(BoundingBoxEncodeCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, BoundingBoxEncodeFunc>> func_list_;
BoundingBoxEncodeFunc kernel_func_;
void InitTaskFunc(const CNodePtr &kernel_node);
std::vector<float> means_;
std::vector<float> stds_;
};
MS_REG_CPU_KERNEL_T(
BoundingBoxEncode,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BoundingBoxEncodeCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(
BoundingBoxEncode,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BoundingBoxEncodeCpuKernelMod, float16);
} // namespace kernel
} // namespace mindspore

View File

@ -15,6 +15,8 @@
*/
#include "plugin/device/cpu/kernel/broadcast_to_cpu_kernel.h"
#include <algorithm>
#include <utility>
#include "plugin/device/cpu/kernel/nnacl/errorcode.h"
namespace mindspore {
@ -23,8 +25,43 @@ namespace {
constexpr size_t kBroadcastToOutputsNum = 1;
} // namespace
template <typename T>
void BroadcastToCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
std::map<std::string, std::vector<std::pair<KernelAttr, BroadcastToCpuKernelMod::BroadcastToFunc>>>
BroadcastToCpuKernelMod::func_list_ = {
{kBroadcastTo,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&BroadcastToCpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
&BroadcastToCpuKernelMod::LaunchKernel<int>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
&BroadcastToCpuKernelMod::LaunchKernel<bool>}}},
{kDynamicBroadcastTo,
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
&BroadcastToCpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
&BroadcastToCpuKernelMod::LaunchKernel<int>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
&BroadcastToCpuKernelMod::LaunchKernel<bool>}}}};
void BroadcastToCpuKernelMod::InitTaskFunc(const CNodePtr &kernel_node) {
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
if (kernel_name_ != kernel_type_) {
MS_LOG(EXCEPTION) << "Suppose to be " << kernel_type_ << " but got " << kernel_name_;
}
auto iter = func_list_.find(kernel_type_);
if (iter == func_list_.end()) {
MS_LOG(EXCEPTION) << "BroadcastTo cpu does not support " << kernel_type_;
}
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "BroadcastTo does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[kernel_type_][index].second;
}
void BroadcastToCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
input_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
@ -40,10 +77,11 @@ void BroadcastToCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
}
shape_info_.input_shape_size_ = SizeToInt(input_shape_size);
shape_info_.output_shape_size_ = SizeToInt(output_shape_size);
InitTaskFunc(kernel_node);
}
template <typename T>
void BroadcastToCpuKernelMod<T>::CheckArgs() {
void BroadcastToCpuKernelMod::CheckArgs() {
size_t input_shape_size = input_shape_.size();
size_t output_shape_size = output_shape_.size();
if (output_shape_size < input_shape_size) {
@ -71,8 +109,8 @@ void BroadcastToCpuKernelMod<T>::CheckArgs() {
}
template <typename T>
bool BroadcastToCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
bool BroadcastToCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kBroadcastToOutputsNum, kernel_name_);
CheckArgs();
const auto *input_addr = reinterpret_cast<T *>(inputs[0]->addr);
@ -99,5 +137,22 @@ bool BroadcastToCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs, c
}
return true;
}
std::vector<KernelAttr> BroadcastToCpuKernelMod::GetOpSupport() {
auto iter = func_list_.find(kernel_type_);
if (iter == func_list_.end()) {
MS_LOG(EXCEPTION) << "Does not support " << kernel_type_ << "!";
}
std::vector<KernelAttr> support_list;
std::transform(iter->second.begin(), iter->second.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, BroadcastToFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, BroadcastTo,
[]() { return std::make_shared<BroadcastToCpuKernelMod>(kBroadcastTo); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeCpuKernelMod, DynamicBroadcastTo,
[]() { return std::make_shared<BroadcastToCpuKernelMod>(kDynamicBroadcastTo); });
} // namespace kernel
} // namespace mindspore
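
BroadcastTo and DynamicBroadcastTo now share one class that is told apart by a constructor argument, which is why the plain MS_KERNEL_FACTORY_REG used elsewhere (it default-constructs the class) cannot serve here: the _BY_CREATOR variant accepts an explicit creator callable. On the factory sketch from the BinaryCrossEntropy note, that variant would amount to no more than the following; the real macro body is not part of this diff, so treat it as an assumption.

// Registers a caller-supplied creator instead of synthesizing one from the
// default constructor:
#define MS_KERNEL_FACTORY_REG_BY_CREATOR_SKETCH(BASE, NAME, CREATOR) \
  static const KernelRegistrarSketch<BASE> g_##NAME##_reg(#NAME, CREATOR)

This is what lets two op names map onto differently configured instances of the same class, exactly as the two registrations above do.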

View File

@ -18,50 +18,54 @@
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BROADCAST_TO_CPU_KERNEL_H_
#include <vector>
#include <map>
#include <memory>
#include <utility>
#include <string>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/cpu/kernel/nnacl/base/broadcast_to.h"
namespace mindspore {
namespace kernel {
template <typename T>
constexpr auto kBroadcastTo = "BroadcastTo";
constexpr auto kDynamicBroadcastTo = "DynamicBroadcastTo";
constexpr auto kUnknown = "Unknown";
class BroadcastToCpuKernelMod : public NativeCpuKernelMod {
public:
BroadcastToCpuKernelMod() = default;
explicit BroadcastToCpuKernelMod(const std::string &kernel_type) : kernel_type_(kernel_type) {}
~BroadcastToCpuKernelMod() = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, workspace, outputs);
}
void InitKernel(const CNodePtr &kernel_node) override;
void CheckArgs();
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
using BroadcastToFunc =
std::function<bool(BroadcastToCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
static std::map<std::string, std::vector<std::pair<KernelAttr, BroadcastToFunc>>> func_list_;
BroadcastToFunc kernel_func_;
void InitTaskFunc(const CNodePtr &kernel_node);
std::vector<size_t> input_shape_;
std::vector<size_t> output_shape_;
BroadcastShapeInfo shape_info_{};
std::string kernel_type_{kUnknown};
};
MS_REG_CPU_KERNEL_T(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastToCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastToCpuKernelMod, int);
MS_REG_CPU_KERNEL_T(BroadcastTo, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
BroadcastToCpuKernelMod, bool);
MS_REG_CPU_KERNEL_T(
DynamicBroadcastTo,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
BroadcastToCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(
DynamicBroadcastTo,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastToCpuKernelMod, int);
MS_REG_CPU_KERNEL_T(
DynamicBroadcastTo,
KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
BroadcastToCpuKernelMod, bool);
} // namespace kernel
} // namespace mindspore

View File

@ -19,18 +19,30 @@
#include <cmath>
#include <map>
#include <string>
#include <utility>
#include <algorithm>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kCastInputsNum = 1;
constexpr size_t kCastOutputsNum = 1;
} // namespace
template <typename S, typename T>
class CastCpuKernelFunc : public CpuKernelFunc {
public:
CastCpuKernelFunc() = default;
~CastCpuKernelFunc() override = default;
bool RunFunc(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
TypeId source_dtype_{kTypeUnknown};
TypeId target_dtype_{kTypeUnknown};
};
template <typename S, typename T>
void Cast(CastCpuKernelMod<S, T> *content, const S *in, T *out, size_t size) {
void Cast(CastCpuKernelFunc<S, T> *content, const S *in, T *out, size_t size) {
auto task = [&in, &out](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(in[i]);
@ -40,28 +52,186 @@ void Cast(CastCpuKernelMod<S, T> *content, const S *in, T *out, size_t size) {
}
template <typename S, typename T>
void CastCpuKernelMod<S, T>::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
source_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
target_dtype_ = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
}
template <typename S, typename T>
bool CastCpuKernelMod<S, T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCastInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCastOutputsNum, kernel_name_);
if (outputs[0]->size == 0) {
MS_LOG(WARNING) << "For '" << kernel_name_ << "', the memory size of output should be greater than 0, but got 0.";
return true;
}
bool CastCpuKernelFunc<S, T>::RunFunc(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
const auto *input = reinterpret_cast<S *>(inputs[0]->addr);
auto *output = reinterpret_cast<T *>(outputs[0]->addr);
MS_LOG(DEBUG) << "Type source: " << typeid(S).name() << "; target: " << typeid(T).name();
Cast<S, T>(this, input, output, outputs[0]->size / sizeof(T));
return true;
}
template <typename S, typename T>
std::shared_ptr<CpuKernelFunc> CreateCastFunc() {
return std::make_shared<CastCpuKernelFunc<S, T>>();
}
using CastCpuKernelFuncCreator = std::function<std::shared_ptr<CpuKernelFunc>()>;
static std::vector<std::pair<KernelAttr, CastCpuKernelFuncCreator>> kernel_attr_lists = {
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<uint8_t, uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<uint8_t, uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<uint8_t, uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt64), CreateCastFunc<uint8_t, uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt8), CreateCastFunc<uint8_t, int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt16), CreateCastFunc<uint8_t, int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32), CreateCastFunc<uint8_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64), CreateCastFunc<uint8_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat16), CreateCastFunc<uint8_t, float16>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<uint8_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<uint8_t, double>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool), CreateCastFunc<uint8_t, bool>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<uint16_t, uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<uint16_t, uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<uint16_t, uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt64), CreateCastFunc<uint16_t, uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt8), CreateCastFunc<uint16_t, int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt16), CreateCastFunc<uint16_t, int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32), CreateCastFunc<uint16_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64), CreateCastFunc<uint16_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat16), CreateCastFunc<uint16_t, float16>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<uint16_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<uint16_t, double>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool), CreateCastFunc<uint16_t, bool>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<uint32_t, uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<uint32_t, uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<uint32_t, uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt64), CreateCastFunc<uint32_t, uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt8), CreateCastFunc<uint32_t, int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt16), CreateCastFunc<uint32_t, int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32), CreateCastFunc<uint32_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), CreateCastFunc<uint32_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat16), CreateCastFunc<uint32_t, float16>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<uint32_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<uint32_t, double>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool), CreateCastFunc<uint32_t, bool>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<uint64_t, uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<uint64_t, uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<uint64_t, uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), CreateCastFunc<uint64_t, uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt8), CreateCastFunc<uint64_t, int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt16), CreateCastFunc<uint64_t, int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32), CreateCastFunc<uint64_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64), CreateCastFunc<uint64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat16), CreateCastFunc<uint64_t, float16>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<uint64_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<uint64_t, double>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeBool), CreateCastFunc<uint64_t, bool>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<int8_t, uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<int8_t, uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<int8_t, uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt64), CreateCastFunc<int8_t, uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), CreateCastFunc<int8_t, int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt16), CreateCastFunc<int8_t, int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32), CreateCastFunc<int8_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64), CreateCastFunc<int8_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat16), CreateCastFunc<int8_t, float16>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<int8_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<int8_t, double>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool), CreateCastFunc<int8_t, bool>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<int16_t, uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<int16_t, uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<int16_t, uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt64), CreateCastFunc<int16_t, uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt8), CreateCastFunc<int16_t, int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), CreateCastFunc<int16_t, int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32), CreateCastFunc<int16_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64), CreateCastFunc<int16_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat16), CreateCastFunc<int16_t, float16>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<int16_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<int16_t, double>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool), CreateCastFunc<int16_t, bool>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<int32_t, uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<int32_t, uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<int32_t, uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64), CreateCastFunc<int32_t, uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), CreateCastFunc<int32_t, int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16), CreateCastFunc<int32_t, int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateCastFunc<int32_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), CreateCastFunc<int32_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), CreateCastFunc<int32_t, float16>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<int32_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<int32_t, double>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), CreateCastFunc<int32_t, bool>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<int64_t, uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<int64_t, uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<int64_t, uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64), CreateCastFunc<int64_t, uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), CreateCastFunc<int64_t, int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16), CreateCastFunc<int64_t, int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), CreateCastFunc<int64_t, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CreateCastFunc<int64_t, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), CreateCastFunc<int64_t, float16>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<int64_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<int64_t, double>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), CreateCastFunc<int64_t, bool>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<float16, uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<float16, uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<float16, uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt64), CreateCastFunc<float16, uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt8), CreateCastFunc<float16, int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt16), CreateCastFunc<float16, int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), CreateCastFunc<float16, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64), CreateCastFunc<float16, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), CreateCastFunc<float16, float16>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<float16, float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<float16, double>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), CreateCastFunc<float16, bool>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<float, uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<float, uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<float, uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt64), CreateCastFunc<float, uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt8), CreateCastFunc<float, int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt16), CreateCastFunc<float, int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), CreateCastFunc<float, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64), CreateCastFunc<float, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat16), CreateCastFunc<float, float16>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<float, float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<float, double>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), CreateCastFunc<float, bool>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<double, uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<double, uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<double, uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt64), CreateCastFunc<double, uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt8), CreateCastFunc<double, int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt16), CreateCastFunc<double, int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32), CreateCastFunc<double, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64), CreateCastFunc<double, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat16), CreateCastFunc<double, float16>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<double, float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<double, double>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool), CreateCastFunc<double, bool>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt8), CreateCastFunc<bool, uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt16), CreateCastFunc<bool, uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt32), CreateCastFunc<bool, uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt64), CreateCastFunc<bool, uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt8), CreateCastFunc<bool, int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt16), CreateCastFunc<bool, int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32), CreateCastFunc<bool, int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64), CreateCastFunc<bool, int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat16), CreateCastFunc<bool, float16>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat32), CreateCastFunc<bool, float>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat64), CreateCastFunc<bool, double>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), CreateCastFunc<bool, bool>}};
} // namespace
void CastCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
source_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
target_dtype_ = AnfAlgo::GetOutputDeviceDataType(kernel_node, 0);
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
std::vector<KernelAttr> support_list;
std::transform(kernel_attr_lists.begin(), kernel_attr_lists.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, CastCpuKernelFuncCreator> &pair) { return pair.first; });
auto [is_match, index] = MatchKernelAttr(kernel_attr, support_list);
if (!is_match) {
MS_LOG(EXCEPTION) << "Cast does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = kernel_attr_lists[index].second();
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Cast, CastCpuKernelMod);
} // namespace kernel
} // namespace mindspore
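
Cast is the one kernel in this batch parameterized on two types at once, so the member-pointer table used elsewhere (one template parameter per entry) does not fit; instead every <S, T> pair is wrapped behind the CpuKernelFunc interface and the table stores creator functions. CpuKernelFunc itself is declared in cpu_kernel.h, which this diff does not show; the sketch below assumes a minimal virtual interface of that shape.

#include <cstddef>
#include <memory>

// Assumed minimal shape of the type-erasing functor interface.
class CpuKernelFuncSketch {
 public:
  virtual ~CpuKernelFuncSketch() = default;
  virtual bool RunFunc(const void *in, void *out, std::size_t n) = 0;
};

template <typename S, typename T>
class CastFuncSketch : public CpuKernelFuncSketch {
 public:
  bool RunFunc(const void *in, void *out, std::size_t n) override {
    const S *src = static_cast<const S *>(in);
    T *dst = static_cast<T *>(out);
    for (std::size_t i = 0; i < n; ++i) {
      dst[i] = static_cast<T>(src[i]);  // the actual cast, element by element
    }
    return true;
  }
};

// One creator per <S, T> combination; only the entry matched at init time is
// ever instantiated into an object.
template <typename S, typename T>
std::shared_ptr<CpuKernelFuncSketch> CreateCastFuncSketch() {
  return std::make_shared<CastFuncSketch<S, T>>();
}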

View File

@ -22,11 +22,10 @@
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
template <typename S, typename T>
class CastCpuKernelMod : public NativeCpuKernelMod {
public:
CastCpuKernelMod() = default;
@ -35,169 +34,24 @@ class CastCpuKernelMod : public NativeCpuKernelMod {
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
const std::vector<AddressPtr> &outputs) override {
const size_t kCastInputsNum = 1;
const size_t kCastOutputsNum = 1;
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCastInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCastOutputsNum, kernel_name_);
if (outputs[0]->size == 0) {
MS_LOG(WARNING) << "For '" << kernel_name_ << "', the memory size of output should be greater than 0, but got 0.";
return true;
}
return kernel_func_->RunFunc(inputs, workspace, outputs);
}
private:
TypeId source_dtype_{kTypeUnknown};
TypeId target_dtype_{kTypeUnknown};
std::shared_ptr<CpuKernelFunc> kernel_func_;
};
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint8_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint8_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint8_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint8_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint8_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint8_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint8_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint8_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint8_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint8_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint8_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint8_t, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint16_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint16_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint16_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint16_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint16_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint16_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint16_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint16_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint16_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint16_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint16_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint16_t, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint32_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint32_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint32_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint32_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint32_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint32_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint32_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint32_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint32_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint32_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint32_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint32_t, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint64_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint64_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint64_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint64_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint64_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint64_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint64_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint64_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint64_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint64_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint64_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, uint64_t, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int8_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int8_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int8_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int8_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int8_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int8_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int8_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int8_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int8_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int8_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int8_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int8_t, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int16_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int16_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int16_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int16_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int16_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int16_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int16_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int16_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int16_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int16_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int16_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int16_t, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int32_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int32_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int32_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int32_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int32_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int32_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int32_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int32_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int32_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int32_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int32_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int32_t, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int64_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int64_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int64_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int64_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int64_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int64_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int64_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int64_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int64_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int64_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int64_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, int64_t, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float16, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float16, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float16, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float16, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float16, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float16, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float16, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float16, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float16, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float16, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float16, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float16, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, float, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, double, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, double, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, double, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, double, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, double, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, double, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, double, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, double, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, double, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, double, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, double, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, double, bool);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, bool, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, bool, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, bool, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, bool, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, bool, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, bool, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, bool, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, bool, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, bool, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, bool, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, bool, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr(), CastCpuKernelMod, bool, bool);
} // namespace kernel
} // namespace mindspore
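
Taken together, the header change is easiest to read end to end: the class loses both template parameters, all 144 MS_REG_CPU_KERNEL_T_S lines go away, and the input/output-count checks that every <S, T> instantiation used to repeat now sit once in the inline Launch, which forwards to the functor chosen at init. A usage-flow sketch against the illustrative factory from the first note (the names are assumptions, not the MindSpore API):

// How the pieces line up at runtime once everything is registered.
bool CreateAndRunCastSketch() {
  auto mod = FactorySketch<NativeCpuKernelMod>::Instance().Create("Cast");
  if (mod == nullptr) {
    return false;  // unregistered op name; the caller treats this as an unsupported kernel
  }
  // mod->Init(kernel_node) would match the node's attr against the table and
  // pick the concrete <S, T> functor; mod->Launch(...) then only validates the
  // input/output counts and calls kernel_func_->RunFunc(...).
  return true;
}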

View File

@ -14,9 +14,10 @@
* limitations under the License.
*/
#include <functional>
#include "plugin/device/cpu/kernel/check_valid_cpu_kernel.h"
#include <functional>
#include <algorithm>
#include <utility>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
@ -29,20 +30,26 @@ constexpr size_t kIndex1 = 1;
constexpr size_t kIndex2 = 2;
} // namespace
template <typename T>
void CheckValidCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
void CheckValidCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
anchor_box_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
img_metas_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "CheckValid does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
}
template <typename T>
bool CheckValidCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CheckParams(inputs, outputs);
bool CheckValidCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CheckParams<T>(inputs, outputs);
auto anchor_box = reinterpret_cast<T *>(inputs[0]->addr);
auto img_metas = reinterpret_cast<T *>(inputs[1]->addr);
auto output = reinterpret_cast<bool *>(outputs[0]->addr);
@ -77,8 +84,8 @@ bool CheckValidCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &in
}
template <typename T>
void CheckValidCpuKernelMod<T>::CheckParams(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
void CheckValidCpuKernelMod::CheckParams(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
// inputs: anchor_box, img_metas
if (inputs.size() != kInputSize) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs should be " << kInputSize << ", but got "
@ -96,5 +103,24 @@ void CheckValidCpuKernelMod<T>::CheckParams(const std::vector<AddressPtr> &input
<< Vector2Str(output_shape_) << ", the shape of 'img_metas': " << Vector2Str(img_metas_shape_);
}
}
std::vector<std::pair<KernelAttr, CheckValidCpuKernelMod::CheckValidFunc>> CheckValidCpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
&CheckValidCpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
&CheckValidCpuKernelMod::LaunchKernel<float16>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
&CheckValidCpuKernelMod::LaunchKernel<int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
&CheckValidCpuKernelMod::LaunchKernel<uint8_t>}};
std::vector<KernelAttr> CheckValidCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, CheckValidFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, CheckValid, CheckValidCpuKernelMod);
} // namespace kernel
} // namespace mindspore
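
One consequence of de-templating the class while keeping templated helpers, visible in the Launch path above: CheckParams<T> can no longer deduce T from its arguments (neither vector mentions T), so LaunchKernel<T> has to pass the type explicitly. A reduced sketch of that method-template pattern:

#include <cstddef>
#include <stdexcept>

struct CheckerSketch {
  template <typename T>
  void CheckParams(std::size_t bytes) const {
    if (bytes % sizeof(T) != 0) {
      throw std::runtime_error("buffer size is not a multiple of the element size");
    }
  }

  template <typename T>
  bool LaunchKernel(std::size_t bytes) const {
    CheckParams<T>(bytes);  // explicit <T>: nothing in the argument list names T
    return true;
  }
};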

View File

@ -18,14 +18,14 @@
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHECK_VALID_CPU_KERNEL_H_
#include <vector>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
constexpr size_t COORDINATE = 4;
template <typename T>
class CheckValidCpuKernelMod : public NativeCpuKernelMod {
public:
CheckValidCpuKernelMod() = default;
@ -33,33 +33,29 @@ class CheckValidCpuKernelMod : public NativeCpuKernelMod {
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
void CheckParams(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
using CheckValidFunc =
std::function<bool(CheckValidCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, CheckValidFunc>> func_list_;
CheckValidFunc kernel_func_;
std::vector<size_t> anchor_box_shape_;
std::vector<size_t> img_metas_shape_;
std::vector<size_t> output_shape_;
};
MS_REG_CPU_KERNEL_T(
CheckValid,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
CheckValidCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(
CheckValid,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
CheckValidCpuKernelMod, float16);
MS_REG_CPU_KERNEL_T(
CheckValid, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
CheckValidCpuKernelMod, int16_t);
MS_REG_CPU_KERNEL_T(
CheckValid, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
CheckValidCpuKernelMod, uint8_t);
} // namespace kernel
} // namespace mindspore

View File

@ -14,15 +14,46 @@
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/cholesky_inverse_cpu_kernel.h"
#include <algorithm>
#include <utility>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kDimNum = 2;
}
using CholeskyInverseFunc = std::function<bool(const CNodePtr &, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
template <typename T>
void CholeskyInverseCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
bool CholeskyInverseKernelFunc(const CNodePtr &node_wpt, const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto input_x0 = reinterpret_cast<T *>(inputs[0]->addr);
auto output_y = reinterpret_cast<T *>(outputs[0]->addr);
auto inputShape = AnfAlgo::GetInputDeviceShape(node_wpt, 0);
int64_t n = SizeToLong(inputShape[0]);
using MatrixXd = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
Eigen::Map<MatrixXd> A(input_x0, n, n);
MatrixXd result;
auto upper = common::AnfAlgo::GetNodeAttr<bool>(node_wpt, "upper");
if (upper) {
result = (A.transpose() * A).inverse();
} else {
result = (A * A.transpose()).inverse();
}
for (int64_t i = 0; i < n; i++) {
for (int64_t j = 0; j < n; j++) {
*(output_y + i * n + j) = result(i, j);
}
}
return true;
}
static std::vector<std::pair<KernelAttr, CholeskyInverseFunc>> kernel_attr_list = {
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CholeskyInverseKernelFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CholeskyInverseKernelFunc<double>}};
} // namespace
void CholeskyInverseCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
node_wpt_ = kernel_node;
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
@ -39,31 +70,24 @@ void CholeskyInverseCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
<< "while row is " << x_shape[x_shape.size() - kDimNum] << ", col is "
<< x_shape[x_shape.size() - 1];
}
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "Cholesky inverse valid cpu kernel does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = std::bind(kernel_attr_list[index].second, node_wpt_, std::placeholders::_1, std::placeholders::_2);
}
template <typename T>
bool CholeskyInverseCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
auto input_x0 = reinterpret_cast<T *>(inputs[0]->addr);
auto output_y = reinterpret_cast<T *>(outputs[0]->addr);
auto inputShape = AnfAlgo::GetInputDeviceShape(node_wpt_, 0);
int64_t n = SizeToLong(inputShape[0]);
typedef Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> MatrixXd;
Eigen::Map<MatrixXd> A(input_x0, n, n);
MatrixXd result;
auto upper = common::AnfAlgo::GetNodeAttr<bool>(node_wpt_, "upper");
if (upper) {
result = (A.transpose() * A).inverse();
} else {
result = (A * A.transpose()).inverse();
}
for (int64_t i = 0; i < n; i++) {
for (int64_t j = 0; j < n; j++) {
*(output_y + i * n + j) = result(i, j);
}
}
return true;
std::vector<KernelAttr> CholeskyInverseCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
std::transform(kernel_attr_list.begin(), kernel_attr_list.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, CholeskyInverseFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, CholeskyInverse, CholeskyInverseCpuKernelMod);
} // namespace kernel
} // namespace mindspore
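For reference, the Eigen arithmetic in CholeskyInverseKernelFunc is self-contained enough to run outside the kernel: given a Cholesky factor it returns (A^T * A)^{-1} for an upper factor and (A * A^T)^{-1} otherwise, i.e. the inverse of the original matrix. A hedged standalone sketch (assumes Eigen is on the include path; function and type names here are illustrative):

```cpp
// Standalone check of the CholeskyInverse math used by the kernel body above.
#include <Eigen/Dense>
#include <iostream>

using MatrixXdRM = Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

MatrixXdRM CholeskyInverse(const MatrixXdRM &factor, bool upper) {
  // Same arithmetic as the kernel: (U^T * U)^{-1} or (L * L^T)^{-1}.
  return upper ? MatrixXdRM((factor.transpose() * factor).inverse())
               : MatrixXdRM((factor * factor.transpose()).inverse());
}

int main() {
  MatrixXdRM spd(2, 2);
  spd << 4, 2, 2, 3;                   // symmetric positive definite input
  MatrixXdRM L = spd.llt().matrixL();  // lower Cholesky factor of spd
  // (L * L^T)^{-1} * spd should be (numerically) the identity matrix.
  std::cout << CholeskyInverse(L, /*upper=*/false) * spd << "\n";
}
```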

View File

@ -18,13 +18,12 @@
#include <Eigen/Dense>
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
constexpr size_t kInputNum = 1;
constexpr size_t kOutputNum = 1;
namespace kernel {
template <typename T>
class CholeskyInverseCpuKernelMod : public NativeCpuKernelMod {
public:
CholeskyInverseCpuKernelMod() = default;
@ -32,16 +31,18 @@ class CholeskyInverseCpuKernelMod : public NativeCpuKernelMod {
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(inputs, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
std::function<bool(const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)> kernel_func_;
CNodePtr node_wpt_;
};
MS_REG_CPU_KERNEL_T(CholeskyInverse, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
CholeskyInverseCpuKernelMod, float)
MS_REG_CPU_KERNEL_T(CholeskyInverse, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
CholeskyInverseCpuKernelMod, double)
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHOLESKYINVERSE_CPU_KERNEL_H_

View File

@ -153,5 +153,26 @@ void CoalesceCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &i
y_shape_addr[i] = x_shape_addr[i];
}
}
std::vector<KernelAttr> CoalesceCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> support_list = {KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt64),
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeInt64)};
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Coalesce, CoalesceCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -27,7 +27,7 @@
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
@ -39,6 +39,9 @@ class CoalesceCpuKernelMod : public NativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
@ -50,26 +53,6 @@ class CoalesceCpuKernelMod : public NativeCpuKernelMod {
size_t jump = 0;
CNodeWeakPtr node_wpt_;
};
MS_REG_CPU_KERNEL(Coalesce,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt64),
CoalesceCpuKernelMod);
MS_REG_CPU_KERNEL(Coalesce,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeInt64),
CoalesceCpuKernelMod);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_COALESCE_CPU_KERNEL_H_

View File

@ -15,6 +15,8 @@
*/
#include "plugin/device/cpu/kernel/concat_cpu_kernel.h"
#include <algorithm>
#include <utility>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
@ -22,9 +24,7 @@ namespace kernel {
namespace {
constexpr size_t kConcatOutputsNum = 1;
} // namespace
template <typename T>
void ConcatCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
void ConcatCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
cnode_ptr_ = kernel_node;
@ -33,12 +33,19 @@ void ConcatCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
if (axis_ < 0) {
axis_ = axis_ + SizeToInt(input_1_shape.size());
}
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "Concat does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
}
template <typename T>
bool ConcatCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
bool ConcatCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto node_ = cnode_ptr_.lock();
if (!node_) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', cnode_ptr_(kernel_node) is expired. Error no: " << node_;
@ -88,5 +95,39 @@ bool ConcatCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs
ParallelLaunchAutoSearch(task, before_axis, this, &parallel_search_info_);
return true;
}
std::vector<std::pair<KernelAttr, ConcatCpuKernelMod::ConcatFunc>> ConcatCpuKernelMod::func_list_ = {
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&ConcatCpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&ConcatCpuKernelMod::LaunchKernel<double>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
&ConcatCpuKernelMod::LaunchKernel<int8_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
&ConcatCpuKernelMod::LaunchKernel<int16_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
&ConcatCpuKernelMod::LaunchKernel<int32_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
&ConcatCpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
&ConcatCpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
&ConcatCpuKernelMod::LaunchKernel<uint16_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
&ConcatCpuKernelMod::LaunchKernel<uint32_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
&ConcatCpuKernelMod::LaunchKernel<uint64_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
&ConcatCpuKernelMod::LaunchKernel<bool>}};
std::vector<KernelAttr> ConcatCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, ConcatFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Concat, ConcatCpuKernelMod);
} // namespace kernel
} // namespace mindspore
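The tables above lean on KernelAttr's fluent interface: each Add* call mutates the object and returns *this, and AddAllSameAttr(true) lets a single entry stand for an op whose variadic inputs all share one dtype (Concat, Stack, Split, and the like). A simplified sketch of that builder shape — Attr is a stand-in, not the real KernelAttr:

```cpp
// Fluent-builder sketch: every setter returns *this so calls chain.
#include <iostream>
#include <string>
#include <utility>
#include <vector>

class Attr {
 public:
  Attr &AddInputAttr(std::string dtype) {
    inputs_.emplace_back(std::move(dtype));
    return *this;  // returning *this enables chained calls
  }
  Attr &AddOutputAttr(std::string dtype) {
    outputs_.emplace_back(std::move(dtype));
    return *this;
  }
  Attr &AddAllSameAttr(bool all_same) {
    all_same_ = all_same;  // one entry covers every same-typed input/output
    return *this;
  }
  void Print() const {
    std::cout << inputs_.size() << " in, " << outputs_.size()
              << " out, all_same=" << all_same_ << "\n";
  }

 private:
  std::vector<std::string> inputs_, outputs_;
  bool all_same_ = false;
};

int main() {
  Attr().AddAllSameAttr(true).AddInputAttr("float32").AddOutputAttr("float32").Print();
}
```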

View File

@ -19,13 +19,13 @@
#include <vector>
#include <memory>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
class ConcatCpuKernelMod : public NativeCpuKernelMod {
public:
ConcatCpuKernelMod() = default;
@ -34,23 +34,22 @@ class ConcatCpuKernelMod : public NativeCpuKernelMod {
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
using ConcatFunc = std::function<bool(ConcatCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, ConcatFunc>> func_list_;
ConcatFunc kernel_func_;
int axis_{0};
};
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCpuKernelMod, float)
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCpuKernelMod, double)
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCpuKernelMod, int8_t)
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCpuKernelMod, int16_t)
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCpuKernelMod, int32_t)
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCpuKernelMod, int64_t)
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCpuKernelMod, uint8_t)
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCpuKernelMod, uint16_t)
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCpuKernelMod, uint32_t)
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCpuKernelMod, uint64_t)
MS_REG_CPU_KERNEL_T(Concat, KernelAttr(), ConcatCpuKernelMod, bool)
} // namespace kernel
} // namespace mindspore

View File

@ -15,6 +15,8 @@
*/
#include "plugin/device/cpu/kernel/concat_offset_cpu_kernel.h"
#include <algorithm>
#include <utility>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
@ -23,9 +25,7 @@ namespace {
constexpr size_t kConcatOffsetOutputNum = 1;
constexpr size_t kConcatOffsetOutputShapeSize = 2;
} // namespace
template <typename T>
void ConcatOffsetCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
void ConcatOffsetCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
cnode_ptr_ = kernel_node;
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
@ -42,12 +42,18 @@ void ConcatOffsetCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
<< "', the 'axis' should be less than the dimension of 'input_x', but got 'axis': " << axis_
<< ", and the dimension of 'input_x': " << input_1_shape.size();
}
}
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "Concat offset does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
}
template <typename T>
bool ConcatOffsetCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
bool ConcatOffsetCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kConcatOffsetOutputNum, kernel_name_);
auto node_ = cnode_ptr_.lock();
if (!node_) {
@ -94,5 +100,37 @@ bool ConcatOffsetCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &
}
return true;
}
std::vector<std::pair<KernelAttr, ConcatOffsetCpuKernelMod::ConcatOffsetFunc>> ConcatOffsetCpuKernelMod::func_list_ = {
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
&ConcatOffsetCpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64),
&ConcatOffsetCpuKernelMod::LaunchKernel<int8_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64),
&ConcatOffsetCpuKernelMod::LaunchKernel<int16_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
&ConcatOffsetCpuKernelMod::LaunchKernel<int32_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
&ConcatOffsetCpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64),
&ConcatOffsetCpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64),
&ConcatOffsetCpuKernelMod::LaunchKernel<uint16_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64),
&ConcatOffsetCpuKernelMod::LaunchKernel<uint32_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64),
&ConcatOffsetCpuKernelMod::LaunchKernel<uint64_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64),
&ConcatOffsetCpuKernelMod::LaunchKernel<bool>}}; // namespace kernel
std::vector<KernelAttr> ConcatOffsetCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, ConcatOffsetFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, ConcatOffset, ConcatOffsetCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -19,12 +19,12 @@
#include <vector>
#include <memory>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
class ConcatOffsetCpuKernelMod : public NativeCpuKernelMod {
public:
ConcatOffsetCpuKernelMod() = default;
@ -33,22 +33,22 @@ class ConcatOffsetCpuKernelMod : public NativeCpuKernelMod {
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
using ConcatOffsetFunc = std::function<bool(ConcatOffsetCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, ConcatOffsetFunc>> func_list_;
ConcatOffsetFunc kernel_func_;
size_t axis_{0};
};
MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCpuKernelMod, int8_t)
MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCpuKernelMod, int16_t)
MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCpuKernelMod, int32_t)
MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCpuKernelMod, int64_t)
MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCpuKernelMod, uint8_t)
MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCpuKernelMod, uint16_t)
MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCpuKernelMod, uint32_t)
MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCpuKernelMod, uint64_t)
MS_REG_CPU_KERNEL_T(ConcatOffset, KernelAttr(), ConcatOffsetCpuKernelMod, bool)
} // namespace kernel
} // namespace mindspore

View File

@ -19,6 +19,8 @@
#include <algorithm>
#include <utility>
#include <cmath>
#include <map>
#include <set>
#include "utils/profile.h"
#include "runtime/graph_scheduler/actor/actor_common.h"
@ -82,6 +84,125 @@ void NativeCpuKernelMod::Init(const CNodePtr &kernel_node) {
InitInputOutputSize(kernel_node);
}
std::vector<TypeId> NativeCpuKernelMod::GetInputDtypes(const CNodePtr &kernel_node) {
std::vector<TypeId> input_types;
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
auto dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index);
input_types.emplace_back(dtype);
}
return input_types;
}
std::vector<std::string> NativeCpuKernelMod::GetInputFormats(const CNodePtr &kernel_node) {
std::vector<std::string> input_formats;
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
input_formats.emplace_back(kOpFormat_DEFAULT);
}
return input_formats;
}
std::vector<TypeId> NativeCpuKernelMod::GetOutputDtypes(const CNodePtr &kernel_node) {
std::vector<TypeId> output_types;
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
auto dtype = common::AnfAlgo::GetOutputInferDataType(kernel_node, output_index);
output_types.emplace_back(dtype);
}
return output_types;
}
std::vector<std::string> NativeCpuKernelMod::GetOutputFormats(const CNodePtr &kernel_node) {
std::vector<std::string> output_formats;
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
output_formats.emplace_back(kOpFormat_DEFAULT);
}
return output_formats;
}
std::vector<KernelAttr> NativeCpuKernelMod::GetAllSupportedList(const std::string &kernel_name) {
if (initialize_.count(kernel_name) == 0) {
std::vector<KernelAttr> kernel_attrs;
auto kernel_support = GetOpSupport();
(void)kernel_attrs.insert(kernel_attrs.end(), kernel_support.begin(), kernel_support.end());
if (kernel_attrs.empty()) {
auto oplib_support = GetSupportFromOpLib(kernel_name);
(void)kernel_attrs.insert(kernel_attrs.end(), oplib_support.begin(), oplib_support.end());
}
(void)support_map_.emplace(kernel_name, kernel_attrs);
initialize_.insert(kernel_name);
}
return support_map_[kernel_name];
}
std::vector<KernelAttr> NativeCpuKernelMod::GetSupportFromOpLib(const std::string &kernel_name) {
static std::set<std::string> same_op_name = {"Concat", "Pack", "Stack", "Split", "Transpose",
"Unpack", "AddN", "ConcatOffset", "DynamicStitch"};
auto op_info = mindspore::kernel::OpLib::FindOp(kernel_name, kernel::OpImplyType::kCPU);
if (op_info == nullptr) {
MS_LOG(EXCEPTION) << "Not find op[" << kernel_name << "] in cpu. For more details, "
<< "please refer to the list of supported cpu operations at https://www.mindspore.cn.";
}
std::vector<KernelAttr> support_kernel_attrs;
auto inputs_ptr = op_info->inputs_ptr();
auto outputs_ptr = op_info->outputs_ptr();
if (outputs_ptr.empty()) {
MS_LOG(EXCEPTION) << "The output dimension of operator '" << kernel_name << "' should not be zero.";
}
auto support_size = outputs_ptr[0]->dtypes().size();
for (size_t i = 0; i < support_size; i++) {
KernelAttr kernel_attr;
for (size_t j = 0; j < inputs_ptr.size(); j++) {
auto input_dtypes = inputs_ptr[j]->dtypes();
auto input_formats = inputs_ptr[j]->formats();
(void)kernel_attr.AddInputAttr(kernel::DtypeToTypeId(input_dtypes[i]), input_formats[i]);
}
for (size_t j = 0; j < outputs_ptr.size(); j++) {
auto output_dtypes = outputs_ptr[j]->dtypes();
auto output_formats = outputs_ptr[j]->formats();
(void)kernel_attr.AddOutputAttr(kernel::DtypeToTypeId(output_dtypes[i]), output_formats[i]);
}
if (same_op_name.count(op_info->op_name()) != 0) {
(void)kernel_attr.AddAllSameAttr(true);
}
support_kernel_attrs.push_back(kernel_attr);
}
return support_kernel_attrs;
}
std::map<std::string, std::vector<KernelAttr>> NativeCpuKernelMod::support_map_{};
std::set<std::string> NativeCpuKernelMod::initialize_{};
void NativeCpuKernelMod::SetCpuRefMapToKernelInfo(const CNodePtr &apply_kernel) {
auto kernel_attrs = GetOpSupport();
if (kernel_attrs.empty()) {
return;
}
auto kernel_attr = GetKernelAttrFromNode(apply_kernel);
auto [is_match, index] = MatchKernelAttr(kernel_attr, kernel_attrs);
if (!is_match) {
MS_LOG(EXCEPTION) << common::AnfAlgo::GetCNodeName(apply_kernel)
<< " does not support this kernel data type: " << kernel_attr;
}
MS_EXCEPTION_IF_NULL(apply_kernel);
auto kernel_info = dynamic_cast<device::KernelInfo *>(apply_kernel->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
const KernelBuildInfo *kernel_build_Info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(kernel_build_Info);
const auto &matched_kernel_attr = kernel_attrs[index];
if (!matched_kernel_attr.GetOutInRefMap().empty()) {
kernel_info->set_ref_map(matched_kernel_attr.GetOutInRefMap());
}
}
void CPUKernelUtils::ExpandDimsTo4(std::vector<size_t> *shape) {
MS_EXCEPTION_IF_NULL(shape);
auto len = shape->size();
@ -435,5 +556,15 @@ void AxisIterator::Init(const std::vector<size_t> &input_shape, size_t axis) {
inner_size_ *= input_shape[i];
}
}
int Sign(float x) {
if (x > 0) {
return 1;
}
if (x < 0) {
return -1;
}
return 0;
}
} // namespace kernel
} // namespace mindspore
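GetAllSupportedList above memoizes per kernel name through the support_map_/initialize_ pair, so the OpLib fallback query runs at most once per op. The same shape in isolation — std::string stands in for KernelAttr, and the names are illustrative:

```cpp
// Memoized support-list lookup: compute on first query, then serve the cache.
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

static std::map<std::string, std::vector<std::string>> support_map;
static std::set<std::string> initialized;

std::vector<std::string> ComputeSupport(const std::string &kernel) {
  std::cout << "computing support for " << kernel << "\n";  // runs only once
  return {kernel + ":float32", kernel + ":int32"};
}

const std::vector<std::string> &GetAllSupportedList(const std::string &kernel) {
  if (initialized.count(kernel) == 0) {
    support_map.emplace(kernel, ComputeSupport(kernel));
    initialized.insert(kernel);
  }
  return support_map[kernel];
}

int main() {
  GetAllSupportedList("Concat");
  GetAllSupportedList("Concat");  // cached: no second "computing" line
}
```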

View File

@ -23,8 +23,11 @@
#include <string>
#include <thread>
#include <vector>
#include <map>
#include <set>
#include "kernel/kernel.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/cpu/kernel/cpu_kernel_mod.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
@ -145,8 +148,21 @@ class NativeCpuKernelMod : public CpuKernelMod {
ParallelSearchInfo parallel_search_info_;
static std::vector<KernelAttr> GetCpuSupportedList(const std::string &kernel_name) {
auto temp_mod = kernel::Factory<NativeCpuKernelMod>::Instance().Create(kernel_name);
if (temp_mod == nullptr) {
MS_LOG(ERROR) << "Get cpu supported list failed!";
return std::vector<KernelAttr>{};
}
return temp_mod->GetAllSupportedList(kernel_name);
}
void SetCpuRefMapToKernelInfo(const CNodePtr &apply_kernel);
protected:
virtual void InitInputOutputSize(const CNodePtr &kernel_node);
virtual std::vector<KernelAttr> GetOpSupport() { return {}; }
CNodeWeakPtr cnode_ptr_;
template <typename T>
@ -162,6 +178,28 @@ class NativeCpuKernelMod : public CpuKernelMod {
return reinterpret_cast<T *>(addr_list[index]->addr);
}
private:
std::vector<KernelAttr> GetAllSupportedList(const std::string &kernel_name);
std::vector<KernelAttr> GetSupportFromOpLib(const std::string &kernel_name);
std::vector<TypeId> GetInputDtypes(const CNodePtr &kernel_node);
std::vector<std::string> GetInputFormats(const CNodePtr &kernel_node);
std::vector<TypeId> GetOutputDtypes(const CNodePtr &kernel_node);
std::vector<std::string> GetOutputFormats(const CNodePtr &kernel_node);
static std::map<std::string, std::vector<KernelAttr>> support_map_;
static std::set<std::string> initialize_;
};
class CpuKernelFunc {
public:
CpuKernelFunc() = default;
virtual ~CpuKernelFunc() = default;
virtual bool RunFunc(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) = 0;
virtual void InitFunc(const CNodePtr &kernel_node) {}
virtual void InitInputOutputSize(const CNodePtr &kernel_node, std::vector<size_t> *input_size_list,
std::vector<size_t> *output_size_list, std::vector<size_t> *workspace_size_list) {}
ParallelSearchInfo parallel_search_info_;
};
class CPUKernelUtils {
@ -248,6 +286,8 @@ class AxisIterator {
size_t inner_size_{0};
size_t axis_offset_{0};
};
int Sign(float x);
} // namespace kernel
} // namespace mindspore
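MS_KERNEL_FACTORY_REG replaces the per-type MS_REG_CPU_KERNEL* macros with a single registration against kernel::Factory<NativeCpuKernelMod>. One common way such a macro can be assembled — a singleton factory plus a file-scope registrar initialized before main — is sketched below; SimpleFactory and DEMO_FACTORY_REG are illustrative, not MindSpore's actual implementation:

```cpp
// Sketch of name-keyed factory registration via static initialization.
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

template <typename Base>
class SimpleFactory {
 public:
  static SimpleFactory &Instance() {
    static SimpleFactory instance;  // one factory per base class
    return instance;
  }
  void Register(const std::string &name, std::function<std::shared_ptr<Base>()> creator) {
    creators_[name] = std::move(creator);
  }
  std::shared_ptr<Base> Create(const std::string &name) const {
    auto it = creators_.find(name);
    return it == creators_.end() ? nullptr : it->second();
  }

 private:
  std::map<std::string, std::function<std::shared_ptr<Base>()>> creators_;
};

// Registration happens as a side effect of initializing a static object.
#define DEMO_FACTORY_REG(BASE, NAME, DERIVED)                   \
  static const bool g_reg_##NAME = [] {                         \
    SimpleFactory<BASE>::Instance().Register(                   \
        #NAME, [] { return std::make_shared<DERIVED>(); });     \
    return true;                                                \
  }()

struct KernelBase { virtual ~KernelBase() = default; virtual const char *Name() = 0; };
struct ConcatKernel : KernelBase { const char *Name() override { return "Concat"; } };
DEMO_FACTORY_REG(KernelBase, Concat, ConcatKernel);

int main() {
  auto k = SimpleFactory<KernelBase>::Instance().Create("Concat");
  std::cout << (k ? k->Name() : "not registered") << "\n";
}
```

A lambda-initialized static bool keeps the expansion to one object per registration and avoids the __COUNTER__ plumbing the old per-type macros needed to generate unique registrar names.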

View File

@ -1,195 +0,0 @@
/**
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include <memory>
#include <set>
#include <string>
#include "runtime/device/kernel_info.h"
#include "plugin/device/cpu/hal/device/kernel_select_cpu.h"
namespace mindspore {
namespace kernel {
namespace {
const std::set<std::string> same_op_name = {"Concat", "Pack", "Stack", "Split", "Transpose",
"Unpack", "AddN", "ConcatOffset", "DynamicStitch"};
}
NativeCpuKernelModFactory &NativeCpuKernelModFactory::GetInstance() {
static NativeCpuKernelModFactory instance;
return instance;
}
void NativeCpuKernelModFactory::Register(const std::string &kernel_name, const KernelAttr &kernel_attr,
NativeCpuKernelModCreator &&kernel_creator) {
(void)name_to_attr_creator_[kernel_name].emplace_back(kernel_attr, kernel_creator);
}
std::shared_ptr<NativeCpuKernelMod> NativeCpuKernelModFactory::Create(const std::string &kernel_name,
const CNodePtr &apply_kernel) {
MS_EXCEPTION_IF_NULL(apply_kernel);
auto kernel_info = dynamic_cast<device::KernelInfo *>(apply_kernel->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);
const KernelBuildInfo *kernel_build_Info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(kernel_build_Info);
std::pair<bool, size_t> ret_pair = CPUKernelAttrCheck(kernel_name, *kernel_build_Info);
if (ret_pair.first) {
SetRefMapToKernelInfo(kernel_name, ret_pair.second, kernel_info);
return (name_to_attr_creator_.find(kernel_name)->second)[ret_pair.second].second();
}
return nullptr;
}
void NativeCpuKernelModFactory::SetKernelAttrs(const std::shared_ptr<kernel::OpInfo> op_info,
std::vector<KernelAttr> *kernel_attrs) {
MS_EXCEPTION_IF_NULL(kernel_attrs);
MS_EXCEPTION_IF_NULL(op_info);
auto inputs_ptr = op_info->inputs_ptr();
auto outputs_ptr = op_info->outputs_ptr();
if (outputs_ptr.empty()) {
MS_LOG(EXCEPTION) << "The output dimension of operator '" << op_info->op_name() << "' should not be zero.";
}
auto first_output_dtypes = outputs_ptr[0]->dtypes();
for (size_t i = 0; i < first_output_dtypes.size(); i++) {
KernelAttr kernel_attr;
for (size_t j = 0; j < inputs_ptr.size(); j++) {
auto input_dtypes = inputs_ptr[j]->dtypes();
auto input_formats = inputs_ptr[j]->formats();
(void)kernel_attr.AddInputAttr(kernel::DtypeToTypeId(input_dtypes[i]), input_formats[i]);
}
for (size_t j = 0; j < outputs_ptr.size(); j++) {
auto output_dtypes = outputs_ptr[j]->dtypes();
auto output_formats = outputs_ptr[j]->formats();
(void)kernel_attr.AddOutputAttr(kernel::DtypeToTypeId(output_dtypes[i]), output_formats[i]);
}
if (same_op_name.count(op_info->op_name()) != 0) {
(void)kernel_attr.SetAllSameAttr(true);
}
(void)kernel_attrs->emplace_back(kernel_attr);
}
}
void NativeCpuKernelModFactory::UpdateKernelAttrs(const std::string &kernel_name,
const std::vector<KernelAttr> &kernel_attrs) {
size_t attr_size = kernel_attrs.size();
std::vector<std::pair<KernelAttr, NativeCpuKernelModCreator>> attr_creators(attr_size);
auto iter = name_to_attr_creator_.find(kernel_name);
if (iter == name_to_attr_creator_.end()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name
<< ", only support these types: Concat, Pack, Stack, Split, Transpose, Unpack, AddN, "
"ConcatOffset or DynamicStitch currently, but got "
<< kernel_name;
}
if (attr_size <= iter->second.size()) {
for (size_t i = 0; i < attr_size; i++) {
auto creator = name_to_attr_creator_.find(kernel_name)->second[i].second;
attr_creators[i] = std::make_pair(kernel_attrs[i], creator);
}
} else {
MS_LOG(INFO) << "attr size is not equal creators size " << kernel_name << " attr_size = " << attr_size
<< " creator_size = " << iter->second.size();
auto single_creator = name_to_attr_creator_.find(kernel_name)->second[0].second;
for (size_t i = 0; i < attr_size; i++) {
attr_creators[i] = std::make_pair(kernel_attrs[i], single_creator);
}
}
name_to_attr_creator_[kernel_name] = attr_creators;
}
std::pair<bool, size_t> NativeCpuKernelModFactory::CPUKernelAttrCheck(const std::string &kernel_name,
const KernelBuildInfo &kernel_info) {
auto iter = name_to_attr_creator_.find(kernel_name);
if (iter == name_to_attr_creator_.end()) {
MS_LOG(INFO) << "Not registered CPU kernel: op[" << kernel_name << "]!";
return std::make_pair(false, 0);
}
if (device::cpu::IsDynamicParamKernel(kernel_name)) {
return std::make_pair(true, 0);
}
auto kernel_attrs = GetSupportedKernelAttrList(kernel_name);
if (kernel_attrs[0].GetInputSize() == 0 && kernel_attrs[0].GetOutputSize() == 0) {
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(kernel_name, kernel::OpImplyType::kCPU);
if (op_info_ptr == nullptr) {
MS_LOG(EXCEPTION) << "Not find op[" << kernel_name << "] in cpu. For more details, "
<< "please refer to the list of supported cpu operations at https://www.mindspore.cn.";
}
kernel_attrs.clear();
SetKernelAttrs(op_info_ptr, &kernel_attrs);
kernel::NativeCpuKernelModFactory::GetInstance().UpdateKernelAttrs(kernel_name, kernel_attrs);
}
for (size_t index = 0; index < kernel_attrs.size(); ++index) {
if (CPUKernelSingleAttrCheck(kernel_attrs[index], kernel_info)) {
return std::make_pair(true, index);
}
}
return std::make_pair(false, 0);
}
bool NativeCpuKernelModFactory::CPUKernelSingleAttrCheck(const KernelAttr &kernel_attr,
const KernelBuildInfo &kernel_info) const {
for (size_t i = 0; i < kernel_info.GetInputNum(); ++i) {
auto dtype = kernel_attr.GetAllSame() ? kernel_attr.GetInputAttr(0).first : kernel_attr.GetInputAttr(i).first;
if (kernel_info.GetInputDeviceType(i) != dtype) {
MS_LOG(DEBUG) << "input index:" << i << ", kernel info type:" << kernel_info.GetInputDeviceType(i)
<< ", register type:" << dtype;
return false;
}
}
for (size_t i = 0; i < kernel_info.GetOutputNum(); ++i) {
auto dtype = kernel_attr.GetAllSame() ? kernel_attr.GetOutputAttr(0).first : kernel_attr.GetOutputAttr(i).first;
if (kernel_info.GetOutputDeviceType(i) != dtype) {
MS_LOG(DEBUG) << "output index:" << i << ", kernel info type:" << kernel_info.GetOutputDeviceType(i)
<< ", register type:" << dtype;
return false;
}
}
return true;
}
void NativeCpuKernelModFactory::SetRefMapToKernelInfo(const std::string &kernel_name, size_t index,
device::KernelInfo *kernel_info) {
const auto &kernel_attr = (name_to_attr_creator_.find(kernel_name)->second)[index].first;
if (!kernel_attr.GetOutInRefMap().empty()) {
kernel_info->set_ref_map(kernel_attr.GetOutInRefMap());
}
}
std::vector<KernelAttr> NativeCpuKernelModFactory::GetSupportedKernelAttrList(const std::string &kernel_name) {
std::vector<KernelAttr> result;
auto iter = name_to_attr_creator_.find(kernel_name);
if (iter == name_to_attr_creator_.end()) {
MS_LOG(EXCEPTION) << "Not register CPU kernel of operator: " << kernel_name;
}
auto creators = iter->second;
result.reserve(creators.size());
for (size_t index = 0; index < creators.size(); ++index) {
auto attr_creator = creators[index];
(void)result.emplace_back(attr_creator.first);
}
return result;
}
bool NativeCpuKernelModFactory::SearchRegisteredOp(const std::string &kernel_name) const {
auto iter = name_to_attr_creator_.find(kernel_name);
return iter != name_to_attr_creator_.end();
}
} // namespace kernel
} // namespace mindspore

View File

@ -1,95 +0,0 @@
/**
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CPU_KERNEL_FACTORY_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CPU_KERNEL_FACTORY_H_
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "kernel/oplib/oplib.h"
#include "plugin/device/cpu/hal/device/kernel_select_cpu.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace kernel {
using mindspore::device::cpu::KernelAttr;
using NativeCpuKernelModCreator = std::function<std::shared_ptr<NativeCpuKernelMod>()>;
class NativeCpuKernelModFactory {
public:
static NativeCpuKernelModFactory &GetInstance();
void Register(const std::string &kernel_name, const KernelAttr &kernel_attr,
NativeCpuKernelModCreator &&kernel_creator);
std::shared_ptr<NativeCpuKernelMod> Create(const std::string &kernel_name, const CNodePtr &apply_kernel);
void SetKernelAttrs(const std::shared_ptr<kernel::OpInfo> op_info, std::vector<KernelAttr> *kernel_attrs);
void UpdateKernelAttrs(const std::string &kernel_name, const std::vector<KernelAttr> &kernel_attrs);
std::vector<KernelAttr> GetSupportedKernelAttrList(const std::string &kernel_name);
bool SearchRegisteredOp(const std::string &kernel_name) const;
private:
NativeCpuKernelModFactory() = default;
~NativeCpuKernelModFactory() = default;
DISABLE_COPY_AND_ASSIGN(NativeCpuKernelModFactory)
std::pair<bool, size_t> CPUKernelAttrCheck(const std::string &kernel_name, const KernelBuildInfo &kernel_info);
bool CPUKernelSingleAttrCheck(const KernelAttr &kernel_attr, const KernelBuildInfo &kernel_info) const;
// Set output and input ref map to kernel info which will be used by graph compiler.
void SetRefMapToKernelInfo(const std::string &kernel_name, size_t index, device::KernelInfo *kernel_info);
std::map<std::string, std::vector<std::pair<KernelAttr, NativeCpuKernelModCreator>>> name_to_attr_creator_;
};
class NativeCpuKernelRegistrar {
public:
NativeCpuKernelRegistrar(const std::string &kernel_name, const KernelAttr &kernel_attr,
NativeCpuKernelModCreator &&kernel_creator) {
NativeCpuKernelModFactory::GetInstance().Register(kernel_name, kernel_attr, std::move(kernel_creator));
}
~NativeCpuKernelRegistrar() = default;
};
#define MS_REG_CPU_KERNEL(OPNAME, ATTR, OPCLASS) MS_REG_CPU_KERNEL_(__COUNTER__, OPNAME, ATTR, OPCLASS)
#define MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS) _MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS)
#define _MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS) \
static_assert(std::is_base_of<NativeCpuKernelMod, OPCLASS>::value, " must be base of NativeCpuKernelMod"); \
static const NativeCpuKernelRegistrar g_cpu_kernel_##COUNT##_reg(#OPNAME, ATTR, \
[]() { return std::make_shared<OPCLASS>(); });
#define MS_REG_CPU_KERNEL_T(OPNAME, ATTR, OPCLASS, T) MS_REG_CPU_KERNEL_T_(__COUNTER__, OPNAME, ATTR, OPCLASS, T)
#define MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) _MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T)
#define _MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) \
static_assert(std::is_base_of<NativeCpuKernelMod, OPCLASS<T>>::value, " must be base of NativeCpuKernelMod"); \
static const NativeCpuKernelRegistrar g_cpu_kernel_##COUNT##_##OPNAME##_##T##_reg( \
#OPNAME, ATTR, []() { return std::make_shared<OPCLASS<T>>(); });
#define MS_REG_CPU_KERNEL_T_S(OPNAME, ATTR, OPCLASS, T, S) \
MS_REG_CPU_KERNEL_T_S_(__COUNTER__, OPNAME, ATTR, OPCLASS, T, S)
#define MS_REG_CPU_KERNEL_T_S_(COUNT, OPNAME, ATTR, OPCLASS, T, S) \
_MS_REG_CPU_KERNEL_T_S_(COUNT, OPNAME, ATTR, OPCLASS, T, S)
#define _MS_REG_CPU_KERNEL_T_S_(COUNT, OPNAME, ATTR, OPCLASS, T, S) \
static_assert(std::is_base_of<NativeCpuKernelMod, OPCLASS<T, S>>::value, " must be base of NativeCpuKernelMod"); \
static const NativeCpuKernelRegistrar g_cpu_kernel_##COUNT##_##OPNAME##_##T##_##S##_reg( \
#OPNAME, ATTR, []() { return std::make_shared<OPCLASS<T, S>>(); });
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CPU_KERNEL_FACTORY_H_

View File

@ -19,8 +19,7 @@
namespace mindspore {
namespace kernel {
template <typename T>
void CropAndResizeCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
void CropAndResizeCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
@ -98,12 +97,13 @@ void CropAndResizeCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
method_ = BILINEAR_V2;
}
extrapolation_value_ = common::AnfAlgo::GetNodeAttr<float>(kernel_node, "extrapolation_value");
InitFunc(kernel_node);
}
template <typename T>
bool CropAndResizeCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
bool CropAndResizeCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto *input_image = reinterpret_cast<T *>(inputs[IMAGE]->addr);
auto *input_boxes = reinterpret_cast<float *>(inputs[BOXES]->addr);
auto *input_box_index = reinterpret_cast<int *>(inputs[BOX_INDEX]->addr);
@ -218,5 +218,153 @@ bool CropAndResizeCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr>
ParallelLaunchAutoSearch(task, IntToSize(output_size_), this, &parallel_search_info_);
return true;
}
std::vector<std::pair<KernelAttr, CropAndResizeCpuKernelMod::CropAndResizeFunc>> CropAndResizeCpuKernelMod::func_list_ =
{{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
&CropAndResizeCpuKernelMod::LaunchKernel<float16>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
&CropAndResizeCpuKernelMod::LaunchKernel<float16>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
&CropAndResizeCpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
&CropAndResizeCpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
&CropAndResizeCpuKernelMod::LaunchKernel<double>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
&CropAndResizeCpuKernelMod::LaunchKernel<double>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
&CropAndResizeCpuKernelMod::LaunchKernel<int8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
&CropAndResizeCpuKernelMod::LaunchKernel<int8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
&CropAndResizeCpuKernelMod::LaunchKernel<int16_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
&CropAndResizeCpuKernelMod::LaunchKernel<int16_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
&CropAndResizeCpuKernelMod::LaunchKernel<int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
&CropAndResizeCpuKernelMod::LaunchKernel<int32_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
&CropAndResizeCpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
&CropAndResizeCpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
&CropAndResizeCpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
&CropAndResizeCpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
&CropAndResizeCpuKernelMod::LaunchKernel<uint16_t>},
{KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
&CropAndResizeCpuKernelMod::LaunchKernel<uint16_t>}};
void CropAndResizeCpuKernelMod::InitFunc(const CNodePtr &kernel_node) {
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "Concat does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
}
std::vector<KernelAttr> CropAndResizeCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, CropAndResizeFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, CropAndResize, CropAndResizeCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -20,9 +20,10 @@
#include <vector>
#include <string>
#include <algorithm>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
@ -40,7 +41,6 @@ constexpr size_t BOX_INDEX = 2;
constexpr size_t CROP_SIZE = 3;
constexpr size_t IMAGE_HEIGHT = 1;
constexpr size_t IMAGE_WEIGHT = 2;
template <typename T>
class CropAndResizeCpuKernelMod : public NativeCpuKernelMod {
public:
CropAndResizeCpuKernelMod() = default;
@ -49,9 +49,22 @@ class CropAndResizeCpuKernelMod : public NativeCpuKernelMod {
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
void InitFunc(const CNodePtr &kernel_node);
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
using CropAndResizeFunc = std::function<bool(CropAndResizeCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, CropAndResizeFunc>> func_list_;
CropAndResizeFunc kernel_func_;
int method_{1};
float extrapolation_value_{0.0};
int output_size_{0};
@ -61,168 +74,6 @@ class CropAndResizeCpuKernelMod : public NativeCpuKernelMod {
int final_width_{0};
int channel_{0};
};
MS_REG_CPU_KERNEL_T(CropAndResize,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
CropAndResizeCpuKernelMod, float16);
MS_REG_CPU_KERNEL_T(CropAndResize,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
CropAndResizeCpuKernelMod, float16);
MS_REG_CPU_KERNEL_T(CropAndResize,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
CropAndResizeCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(CropAndResize,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
CropAndResizeCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(CropAndResize,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
CropAndResizeCpuKernelMod, double);
MS_REG_CPU_KERNEL_T(CropAndResize,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
CropAndResizeCpuKernelMod, double);
MS_REG_CPU_KERNEL_T(CropAndResize,
KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
CropAndResizeCpuKernelMod, int8_t);
MS_REG_CPU_KERNEL_T(CropAndResize,
KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
CropAndResizeCpuKernelMod, int8_t);
MS_REG_CPU_KERNEL_T(CropAndResize,
KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
CropAndResizeCpuKernelMod, int16_t);
MS_REG_CPU_KERNEL_T(CropAndResize,
KernelAttr()
.AddInputAttr(kNumberTypeInt16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
CropAndResizeCpuKernelMod, int16_t);
MS_REG_CPU_KERNEL_T(CropAndResize,
KernelAttr()
.AddInputAttr(kNumberTypeInt8)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
CropAndResizeCpuKernelMod, int8_t);
MS_REG_CPU_KERNEL_T(CropAndResize,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
CropAndResizeCpuKernelMod, int32_t);
MS_REG_CPU_KERNEL_T(CropAndResize,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
CropAndResizeCpuKernelMod, int64_t);
MS_REG_CPU_KERNEL_T(CropAndResize,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
CropAndResizeCpuKernelMod, int64_t);
MS_REG_CPU_KERNEL_T(CropAndResize,
KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
CropAndResizeCpuKernelMod, uint8_t);
MS_REG_CPU_KERNEL_T(CropAndResize,
KernelAttr()
.AddInputAttr(kNumberTypeUInt8)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
CropAndResizeCpuKernelMod, uint8_t);
MS_REG_CPU_KERNEL_T(CropAndResize,
KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
CropAndResizeCpuKernelMod, uint16_t);
MS_REG_CPU_KERNEL_T(CropAndResize,
KernelAttr()
.AddInputAttr(kNumberTypeUInt16)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeFloat32),
CropAndResizeCpuKernelMod, uint16_t);
} // namespace kernel
} // namespace mindspore

View File

@ -170,5 +170,44 @@ bool CrossCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inpu
}
return true;
}
std::vector<KernelAttr> CrossCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> support_list = {
KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128)};
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Cross, CrossCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -22,7 +22,7 @@
#include <string>
#include <complex>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
@ -36,6 +36,9 @@ class CrossCpuKernelMod : public NativeCpuKernelMod {
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
std::vector<size_t> input1_shape_;
std::vector<size_t> input2_shape_;
@ -43,67 +46,6 @@ class CrossCpuKernelMod : public NativeCpuKernelMod {
int64_t dim_;
TypeId input1_dtype_;
};
MS_REG_CPU_KERNEL(
Cross, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
CrossCpuKernelMod);
MS_REG_CPU_KERNEL(
Cross, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
CrossCpuKernelMod);
MS_REG_CPU_KERNEL(
Cross, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
CrossCpuKernelMod);
MS_REG_CPU_KERNEL(
Cross, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
CrossCpuKernelMod);
MS_REG_CPU_KERNEL(
Cross, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
CrossCpuKernelMod);
MS_REG_CPU_KERNEL(
Cross, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
CrossCpuKernelMod);
MS_REG_CPU_KERNEL(
Cross, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
CrossCpuKernelMod);
MS_REG_CPU_KERNEL(
Cross, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
CrossCpuKernelMod);
MS_REG_CPU_KERNEL(
Cross,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
CrossCpuKernelMod);
MS_REG_CPU_KERNEL(
Cross,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
CrossCpuKernelMod);
MS_REG_CPU_KERNEL(
Cross,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
CrossCpuKernelMod);
MS_REG_CPU_KERNEL(Cross,
KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
CrossCpuKernelMod);
MS_REG_CPU_KERNEL(Cross,
KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
CrossCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -347,5 +347,26 @@ void CTCLossCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
}
}
}
std::vector<KernelAttr> CTCLossCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> support_list = {KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)};
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, CTCLoss, CTCLossCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -24,7 +24,7 @@
#include <limits>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
@ -38,6 +38,9 @@ class CTCLossCpuKernelMod : public NativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
void GenLabelWithBlank(const uint32_t *seq_len, const std::vector<std::vector<uint32_t>> &batch_label,
std::vector<std::vector<uint32_t>> *label_with_blank) const;
@ -68,26 +71,6 @@ class CTCLossCpuKernelMod : public NativeCpuKernelMod {
bool ctc_merge_repeated_{false};
bool ignore_longer_outputs_than_inputs_{false};
};
MS_REG_CPU_KERNEL(CTCLoss,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
CTCLossCpuKernelMod);
MS_REG_CPU_KERNEL(CTCLoss,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
CTCLossCpuKernelMod);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CTCLOSS_CPU_KERNEL_H_

View File

@ -15,22 +15,30 @@
*/
#include "plugin/device/cpu/kernel/cummax_cpu_kernel.h"
#include <algorithm>
#include <utility>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
namespace kernel {
template <typename T>
void CummaxCPUKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
void CummaxCPUKernelMod::InitKernel(const CNodePtr &kernel_node) {
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
output1_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
output2_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 1);
dim_ = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "dim");
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "Cummax does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
}
template <typename T>
bool CummaxCPUKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
bool CummaxCPUKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto input_data_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto output1_data_addr = reinterpret_cast<T *>(outputs[0]->addr);
auto output2_data_addr = reinterpret_cast<int64_t *>(outputs[1]->addr);
@ -98,5 +106,31 @@ bool CummaxCPUKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs
}
return true;
}
std::vector<std::pair<KernelAttr, CummaxCPUKernelMod::CummaxFunc>> CummaxCPUKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
&CummaxCPUKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64),
&CummaxCPUKernelMod::LaunchKernel<float16>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
&CummaxCPUKernelMod::LaunchKernel<int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
&CummaxCPUKernelMod::LaunchKernel<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64),
&CummaxCPUKernelMod::LaunchKernel<int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64),
&CummaxCPUKernelMod::LaunchKernel<uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64),
&CummaxCPUKernelMod::LaunchKernel<uint32_t>}};
std::vector<KernelAttr> CummaxCPUKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, CummaxFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Cummax, CummaxCPUKernelMod);
} // namespace kernel
} // namespace mindspore
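
The Cummax change above is the template for the whole refactor: the class drops its own template parameter, a static func_list_ pairs each KernelAttr with a LaunchKernel<T> instantiation, InitKernel matches the node once and caches the bound functor, and Launch just forwards. A compilable stand-alone sketch of that dispatch shape follows; TypeId and the class here are bare stand-ins, not MindSpore APIs.

// Sketch: type-erased dispatch to a templated launcher, selected once in Init.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <utility>
#include <vector>

enum TypeId { kFloat32, kInt32 };

class CumMaxLike {
 public:
  void Init(TypeId input_type) {
    // Stand-in for MatchKernelAttr: pick the first table row whose attribute
    // matches the node, then cache the bound launcher.
    for (const auto &entry : func_list_) {
      if (entry.first == input_type) {
        kernel_func_ = entry.second;
        return;
      }
    }
    throw std::runtime_error("unsupported kernel data type");
  }
  bool Launch(const void *in, void *out, size_t n) { return kernel_func_(this, in, out, n); }

 private:
  template <typename T>
  bool LaunchKernel(const void *in, void *out, size_t n) {
    if (n == 0) return true;
    const T *src = static_cast<const T *>(in);
    T *dst = static_cast<T *>(out);
    T running = src[0];
    for (size_t i = 0; i < n; ++i) {
      running = std::max(running, src[i]);  // running maximum, as in Cummax
      dst[i] = running;
    }
    return true;
  }
  using CumMaxFunc = std::function<bool(CumMaxLike *, const void *, void *, size_t)>;
  inline static std::vector<std::pair<TypeId, CumMaxFunc>> func_list_ = {
      {kFloat32, &CumMaxLike::LaunchKernel<float>},
      {kInt32, &CumMaxLike::LaunchKernel<int32_t>}};
  CumMaxFunc kernel_func_;
};

int main() {
  float in[4] = {1.0f, 3.0f, 2.0f, 5.0f};
  float out[4] = {};
  CumMaxLike mod;
  mod.Init(kFloat32);
  mod.Launch(in, out, 4);
  for (float v : out) std::cout << v << ' ';  // prints: 1 3 3 5
  return 0;
}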

View File

@ -19,12 +19,12 @@
#include <memory>
#include <vector>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
class CummaxCPUKernelMod : public NativeCpuKernelMod {
public:
CummaxCPUKernelMod() = default;
@ -32,38 +32,25 @@ class CummaxCPUKernelMod : public NativeCpuKernelMod {
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
using CummaxFunc = std::function<bool(CummaxCPUKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, CummaxFunc>> func_list_;
CummaxFunc kernel_func_;
std::vector<size_t> input_shape_;
std::vector<size_t> output1_shape_;
std::vector<size_t> output2_shape_;
size_t dim_;
};
MS_REG_CPU_KERNEL_T(
Cummax,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
CummaxCPUKernelMod, float);
MS_REG_CPU_KERNEL_T(
Cummax,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64),
CummaxCPUKernelMod, float16);
MS_REG_CPU_KERNEL_T(
Cummax, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
CummaxCPUKernelMod, int32_t);
MS_REG_CPU_KERNEL_T(
Cummax, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
CummaxCPUKernelMod, int64_t);
MS_REG_CPU_KERNEL_T(
Cummax, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64),
CummaxCPUKernelMod, int8_t);
MS_REG_CPU_KERNEL_T(
Cummax, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64),
CummaxCPUKernelMod, uint8_t);
MS_REG_CPU_KERNEL_T(
Cummax, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64),
CummaxCPUKernelMod, uint32_t);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMMAX_CPU_KERNEL_H_

View File

@ -240,5 +240,17 @@ void CumSumCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inp
};
ParallelLaunchAutoSearch(task, lens, this, &parallel_search_info_);
}
std::vector<KernelAttr> CumSumCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> support_list = {
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32)},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16)},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32)},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8)},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8)}};
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, CumSum, CumSumCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -21,7 +21,7 @@
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
@ -35,6 +35,9 @@ class CumSumCpuKernelMod : public NativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
void Reshape();
@ -81,17 +84,6 @@ class CumSumCpuKernelMod : public NativeCpuKernelMod {
int axis_{0};
TypeId dtype_{kTypeUnknown};
};
MS_REG_CPU_KERNEL(CumSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
CumSumCpuKernelMod);
MS_REG_CPU_KERNEL(CumSum, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
CumSumCpuKernelMod);
MS_REG_CPU_KERNEL(CumSum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
CumSumCpuKernelMod);
MS_REG_CPU_KERNEL(CumSum, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
CumSumCpuKernelMod);
MS_REG_CPU_KERNEL(CumSum, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
CumSumCpuKernelMod);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CUMSUM_CPU_KERNEL_H_

View File

@ -25,7 +25,6 @@
#include <functional>
#include "abstract/utils.h"
#include "plugin/device/cpu/hal/device/cpu_common.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "utils/file_utils.h"
namespace mindspore {

View File

@ -21,7 +21,6 @@
#include <functional>
#include "abstract/utils.h"
#include "plugin/device/cpu/hal/device/cpu_common.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/device/cpu/kernel/custom/julia_api.h"
#include "utils/file_utils.h"

View File

@ -44,5 +44,13 @@ bool DebugCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, co
}
return true;
}
std::vector<KernelAttr> DebugCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> support_list = {
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32)};
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Debug, DebugCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -19,9 +19,10 @@
#include <vector>
#include <memory>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
@ -34,10 +35,10 @@ class DebugCpuKernelMod : public NativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
};
MS_REG_CPU_KERNEL(Debug, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
DebugCpuKernelMod);
protected:
std::vector<KernelAttr> GetOpSupport() override;
};
} // namespace kernel
} // namespace mindspore

View File

@ -15,6 +15,8 @@
*/
#include "plugin/device/cpu/kernel/depthtospace_cpu_kernel.h"
#include <algorithm>
#include <utility>
namespace mindspore {
namespace kernel {
@ -23,19 +25,25 @@ constexpr size_t kDepthToSpaceInputsNum = 1;
constexpr size_t kDepthToSpaceOutputsNum = 1;
} // namespace
template <typename T>
void DepthToSpaceCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
void DepthToSpaceCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
block_size_ = LongToSize(common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "block_size"));
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "DepthToSpace does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
}
template <typename T>
bool DepthToSpaceCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /* workspace */,
const std::vector<kernel::AddressPtr> &outputs) {
bool DepthToSpaceCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kDepthToSpaceInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kDepthToSpaceOutputsNum, kernel_name_);
auto input_addr = reinterpret_cast<T *>(inputs[0]->addr);
@ -76,5 +84,37 @@ bool DepthToSpaceCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &
ParallelLaunchAutoSearch(task, size, this, &parallel_search_info_);
return true;
}
std::vector<std::pair<KernelAttr, DepthToSpaceCpuKernelMod::DepthToSpaceFunc>> DepthToSpaceCpuKernelMod::func_list_ = {
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&DepthToSpaceCpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&DepthToSpaceCpuKernelMod::LaunchKernel<float16>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
&DepthToSpaceCpuKernelMod::LaunchKernel<int8_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
&DepthToSpaceCpuKernelMod::LaunchKernel<int16_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
&DepthToSpaceCpuKernelMod::LaunchKernel<int>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
&DepthToSpaceCpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
&DepthToSpaceCpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
&DepthToSpaceCpuKernelMod::LaunchKernel<uint16_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
&DepthToSpaceCpuKernelMod::LaunchKernel<uint32_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
&DepthToSpaceCpuKernelMod::LaunchKernel<uint64_t>}};
std::vector<KernelAttr> DepthToSpaceCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, DepthToSpaceFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, DepthToSpace, DepthToSpaceCpuKernelMod);
} // namespace kernel
} // namespace mindspore
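
The same GetOpSupport body recurs in every file of this commit: project the attrs out of the dispatch table with std::transform so the advertised support list can never drift from the launchers. Isolated as a sketch below; Attr and Func are stand-ins for KernelAttr and the launch functor.

// Sketch: deriving the support list from the dispatch table.
#include <algorithm>
#include <functional>
#include <iterator>
#include <string>
#include <utility>
#include <vector>

using Attr = std::string;            // stand-in for KernelAttr
using Func = std::function<bool()>;  // stand-in for the launch functor

std::vector<Attr> SupportFrom(const std::vector<std::pair<Attr, Func>> &table) {
  std::vector<Attr> support;
  support.reserve(table.size());
  std::transform(table.begin(), table.end(), std::back_inserter(support),
                 [](const std::pair<Attr, Func> &entry) { return entry.first; });
  return support;
}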

View File

@ -18,68 +18,39 @@
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DEPTHTOSPACE_CPU_KERNEL_H_
#include <vector>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
class DepthToSpaceCpuKernelMod : public NativeCpuKernelMod {
public:
DepthToSpaceCpuKernelMod() = default;
~DepthToSpaceCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
using DepthToSpaceFunc = std::function<bool(DepthToSpaceCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, DepthToSpaceFunc>> func_list_;
DepthToSpaceFunc kernel_func_;
std::vector<size_t> input_shape_;
std::vector<size_t> output_shape_;
size_t block_size_{0};
};
MS_REG_CPU_KERNEL_T(
DepthToSpace, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
DepthToSpaceCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(
DepthToSpace, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
DepthToSpaceCpuKernelMod, float16);
MS_REG_CPU_KERNEL_T(DepthToSpace,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
DepthToSpaceCpuKernelMod, int8_t);
MS_REG_CPU_KERNEL_T(DepthToSpace,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
DepthToSpaceCpuKernelMod, int16_t);
MS_REG_CPU_KERNEL_T(DepthToSpace,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
DepthToSpaceCpuKernelMod, int);
MS_REG_CPU_KERNEL_T(DepthToSpace,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
DepthToSpaceCpuKernelMod, int64_t);
MS_REG_CPU_KERNEL_T(DepthToSpace,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
DepthToSpaceCpuKernelMod, uint8_t);
MS_REG_CPU_KERNEL_T(DepthToSpace,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
DepthToSpaceCpuKernelMod, uint16_t);
MS_REG_CPU_KERNEL_T(DepthToSpace,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
DepthToSpaceCpuKernelMod, uint32_t);
MS_REG_CPU_KERNEL_T(DepthToSpace,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
DepthToSpaceCpuKernelMod, uint64_t);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DEPTHTOSPACE_CPU_KERNEL_H_

View File

@ -75,5 +75,7 @@ void DropoutCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
output_addr[i] = mask_addr[i] * input_addr[i] * scale;
}
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Dropout, DropoutCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -21,7 +21,7 @@
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
@ -46,8 +46,6 @@ class DropoutCpuKernelMod : public NativeCpuKernelMod {
float keep_prob_{0.0};
uint64_t tensor_size_{1};
};
MS_REG_CPU_KERNEL(Dropout, KernelAttr(), DropoutCpuKernelMod);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DROPOUT_CPU_KERNEL_H_

View File

@ -101,5 +101,7 @@ void DropoutGradBwdCpuKernelMod::DropoutBackwardKernel(const std::vector<Address
DropoutGrad(input, mask, output, SizeToInt(num_count_), scale);
}
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, DropoutGrad, DropoutGradBwdCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -23,7 +23,7 @@
#include <unordered_map>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
@ -44,8 +44,6 @@ class DropoutGradBwdCpuKernelMod : public NativeCpuKernelMod {
size_t num_count_{1};
TypeId dtype_{kTypeUnknown};
};
MS_REG_CPU_KERNEL(DropoutGrad, KernelAttr(), DropoutGradBwdCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -111,5 +111,17 @@ void DynamicAssignCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inpu
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', output should be a Parameter.";
}
}
std::vector<KernelAttr> DynamicAssignCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> support_list = {
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64)};
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, DynamicAssign, DynamicAssignCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -22,7 +22,7 @@
#include <unordered_map>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
@ -36,6 +36,9 @@ class DynamicAssignCpuKernelMod : public NativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
@ -45,26 +48,6 @@ class DynamicAssignCpuKernelMod : public NativeCpuKernelMod {
size_t input_x_dtype_size_{4};
CNodeWeakPtr node_wpt_;
};
MS_REG_CPU_KERNEL(
DynamicAssign,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
DynamicAssignCpuKernelMod);
MS_REG_CPU_KERNEL(
DynamicAssign,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
DynamicAssignCpuKernelMod);
MS_REG_CPU_KERNEL(
DynamicAssign,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
DynamicAssignCpuKernelMod);
MS_REG_CPU_KERNEL(
DynamicAssign,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
DynamicAssignCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -22,9 +22,7 @@ namespace kernel {
namespace {
constexpr size_t kDynamicShapeOutputNum = 1;
} // namespace
template <typename T>
void TensorShapeCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
void TensorShapeCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
cnode_ptr_ = kernel_node;
@ -34,10 +32,9 @@ void TensorShapeCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
}
}
template <typename T>
bool TensorShapeCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
bool TensorShapeCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kDynamicShapeOutputNum, kernel_name_);
auto node_ = cnode_ptr_.lock();
if (node_ == nullptr) {
@ -60,5 +57,8 @@ bool TensorShapeCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &i
}
return true;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, DynamicShape, TensorShapeCpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, TensorShape, TensorShapeCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -20,11 +20,10 @@
#include <vector>
#include <memory>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
class TensorShapeCpuKernelMod : public NativeCpuKernelMod {
public:
TensorShapeCpuKernelMod() = default;
@ -35,28 +34,6 @@ class TensorShapeCpuKernelMod : public NativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
};
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, int8_t)
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, int16_t)
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, int32_t)
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, int64_t)
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, uint8_t)
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, uint16_t)
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, uint32_t)
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, uint64_t)
MS_REG_CPU_KERNEL_T(DynamicShape, KernelAttr(), TensorShapeCpuKernelMod, bool)
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, int8_t)
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, int16_t)
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, int32_t)
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, int64_t)
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, uint8_t)
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, uint16_t)
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, uint32_t)
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, uint64_t)
MS_REG_CPU_KERNEL_T(TensorShape, KernelAttr(), TensorShapeCpuKernelMod, bool)
} // namespace kernel
} // namespace mindspore

View File

@ -17,6 +17,7 @@
#include "plugin/device/cpu/kernel/dynamic_stitch_cpu_kernel.h"
#include <functional>
#include <algorithm>
#include <utility>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
namespace mindspore {
@ -24,19 +25,13 @@ namespace kernel {
namespace {
constexpr size_t kDynamicStitchOutputNum = 1;
} // namespace
template <typename T>
void DynamicStitchCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
cnode_ptr_ = kernel_node;
}
size_t GetShapeSize(const std::vector<size_t> &shape) {
return std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies<size_t>());
}
template <typename T>
void DynamicStitchCpuKernelMod<T>::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
bool DynamicStitchCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kDynamicStitchOutputNum, kernel_name_);
auto node_ = cnode_ptr_.lock();
int first_dim_size = 0;
@ -84,14 +79,86 @@ void DynamicStitchCpuKernelMod<T>::LaunchKernel(const std::vector<kernel::Addres
}
}
}
}
template <typename T>
bool DynamicStitchCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
LaunchKernel(inputs, outputs);
return true;
}
std::vector<std::pair<KernelAttr, DynamicStitchCpuKernelMod::DynamicStitchFunc>> DynamicStitchCpuKernelMod::func_list_ =
{{KernelAttr()
.AddAllSameAttr(true)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&DynamicStitchCpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddAllSameAttr(true)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt8)
.AddOutputAttr(kNumberTypeInt8),
&DynamicStitchCpuKernelMod::LaunchKernel<int8_t>},
{KernelAttr()
.AddAllSameAttr(true)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt16)
.AddOutputAttr(kNumberTypeInt16),
&DynamicStitchCpuKernelMod::LaunchKernel<int16_t>},
{KernelAttr()
.AddAllSameAttr(true)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
&DynamicStitchCpuKernelMod::LaunchKernel<int32_t>},
{KernelAttr()
.AddAllSameAttr(true)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
&DynamicStitchCpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr()
.AddAllSameAttr(true)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt8)
.AddOutputAttr(kNumberTypeUInt8),
&DynamicStitchCpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr()
.AddAllSameAttr(true)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt16)
.AddOutputAttr(kNumberTypeUInt16),
&DynamicStitchCpuKernelMod::LaunchKernel<uint16_t>},
{KernelAttr()
.AddAllSameAttr(true)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt32)
.AddOutputAttr(kNumberTypeUInt32),
&DynamicStitchCpuKernelMod::LaunchKernel<uint32_t>},
{KernelAttr()
.AddAllSameAttr(true)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeUInt64)
.AddOutputAttr(kNumberTypeUInt64),
&DynamicStitchCpuKernelMod::LaunchKernel<uint64_t>},
{KernelAttr()
.AddAllSameAttr(true)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeBool),
&DynamicStitchCpuKernelMod::LaunchKernel<bool>}};
void DynamicStitchCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
cnode_ptr_ = kernel_node;
std::vector<KernelAttr> support_list;
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, DynamicStitchFunc> &pair) { return pair.first; });
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, support_list);
if (!is_match) {
MS_LOG(EXCEPTION) << "DynamicStitch does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, DynamicStitch, DynamicStitchCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -19,12 +19,13 @@
#include <vector>
#include <memory>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
class DynamicStitchCpuKernelMod : public NativeCpuKernelMod {
public:
DynamicStitchCpuKernelMod() = default;
@ -32,26 +33,20 @@ class DynamicStitchCpuKernelMod : public NativeCpuKernelMod {
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, outputs);
}
private:
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
using DynamicStitchFunc = std::function<bool(DynamicStitchCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, DynamicStitchFunc>> func_list_;
DynamicStitchFunc kernel_func_;
size_t input_tuple_num_{1};
};
MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCpuKernelMod, int8_t)
MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCpuKernelMod, int16_t)
MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCpuKernelMod, int32_t)
MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCpuKernelMod, int64_t)
MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCpuKernelMod, uint8_t)
MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCpuKernelMod, uint16_t)
MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCpuKernelMod, uint32_t)
MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCpuKernelMod, uint64_t)
MS_REG_CPU_KERNEL_T(DynamicStitch, KernelAttr(), DynamicStitchCpuKernelMod, bool)
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DYNAMIC_STITCH_CPU_KERNEL_H_

View File

@ -15,7 +15,9 @@
*/
#include "plugin/device/cpu/kernel/eigen/cholesky_cpu_kernel.h"
#include <algorithm>
#include <vector>
#include <utility>
#include "plugin/device/cpu/kernel/eigen/eigen_common_utils.h"
#include "utils/ms_utils.h"
#include "Eigen/Cholesky"
@ -30,8 +32,7 @@ constexpr size_t kRowIndex = 2;
constexpr size_t kColIndex = 1;
} // namespace
template <typename T>
void CholeskyCpuKernelMod<T>::InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
void CholeskyCpuKernelMod::InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
if (shape.empty()) {
MS_LOG_EXCEPTION << kernel_name_ << " input or output shape is empty which is invalid.";
}
@ -52,8 +53,7 @@ void CholeskyCpuKernelMod<T>::InitMatrixInfo(const std::vector<size_t> &shape, s
outer_batch_ /= ((*row) * (*col));
}
template <typename T>
void CholeskyCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
void CholeskyCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
@ -72,11 +72,18 @@ void CholeskyCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
if (common::AnfAlgo::HasNodeAttr(LOWER, kernel_node)) {
lower_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, LOWER);
}
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "Cholesky does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
}
template <typename T>
bool CholeskyCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
bool CholeskyCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
T *batch_input_value = reinterpret_cast<T *>(inputs[kInputIndex]->addr);
T *batch_output_value = reinterpret_cast<T *>(outputs[kOutputIndex]->addr);
Eigen::LLT<Matrix<T, RowMajor>> llt;
@ -102,5 +109,20 @@ bool CholeskyCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs, cons
}
return true;
}
std::vector<std::pair<KernelAttr, CholeskyCpuKernelMod::CholeskyFunc>> CholeskyCpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&CholeskyCpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&CholeskyCpuKernelMod::LaunchKernel<double>}};
std::vector<KernelAttr> CholeskyCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, CholeskyFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Cholesky, CholeskyCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -17,12 +17,12 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHOLESKY_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHOLESKY_CPU_KERNEL_H_
#include <vector>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
class CholeskyCpuKernelMod : public NativeCpuKernelMod {
public:
CholeskyCpuKernelMod() = default;
@ -30,10 +30,25 @@ class CholeskyCpuKernelMod : public NativeCpuKernelMod {
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
void InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col);
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
using CholeskyFunc =
std::function<bool(CholeskyCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, CholeskyFunc>> func_list_;
CholeskyFunc kernel_func_;
bool lower_{true};
bool clean_{true};
size_t outer_batch_{1};
@ -43,12 +58,6 @@ class CholeskyCpuKernelMod : public NativeCpuKernelMod {
size_t output_col_{1};
TypeId dtype_{kNumberTypeFloat32};
};
MS_REG_CPU_KERNEL_T(Cholesky, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
CholeskyCpuKernelMod, float)
MS_REG_CPU_KERNEL_T(Cholesky, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
CholeskyCpuKernelMod, double)
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHOLESKY_CPU_KERNEL_H_

View File

@ -15,6 +15,8 @@
*/
#include "plugin/device/cpu/kernel/eigen/cholesky_solve_cpu_kernel.h"
#include <algorithm>
#include <utility>
#include "plugin/device/cpu/kernel/eigen/eigen_common_utils.h"
#include "Eigen/Cholesky"
namespace mindspore {
@ -29,8 +31,7 @@ constexpr size_t kRowIndex = 2;
constexpr size_t kColIndex = 1;
} // namespace
template <typename T>
void CholeskySolveCpuKernelMod<T>::InitRightMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
void CholeskySolveCpuKernelMod::InitRightMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
if (shape.empty()) {
MS_LOG_EXCEPTION << kernel_name_ << " input shape is empty which is invalid.";
}
@ -47,9 +48,8 @@ void CholeskySolveCpuKernelMod<T>::InitRightMatrixInfo(const std::vector<size_t>
outer_batch_ /= ((*row) * (*col));
}
template <typename T>
void CholeskySolveCpuKernelMod<T>::InitLeftMatrixInfo(const std::vector<size_t> &shape, const bool is_rank_equal,
size_t *row, size_t *col) {
void CholeskySolveCpuKernelMod::InitLeftMatrixInfo(const std::vector<size_t> &shape, const bool is_rank_equal,
size_t *row, size_t *col) {
if (shape.empty()) {
MS_LOG_EXCEPTION << kernel_name_ << " input or output shape is empty which is invalid.";
}
@ -62,8 +62,7 @@ void CholeskySolveCpuKernelMod<T>::InitLeftMatrixInfo(const std::vector<size_t>
}
}
template <typename T>
void CholeskySolveCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
void CholeskySolveCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
@ -85,11 +84,18 @@ void CholeskySolveCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
MS_LOG_EXCEPTION << kernel_name_ << " llt solve input a row is not match to b row: " << input_a_row_ << " vs "
<< input_b_row_;
}
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "CholeskySolve does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
}
template <typename T>
bool CholeskySolveCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
bool CholeskySolveCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
T *batch_input_value = reinterpret_cast<T *>(inputs[kInputAIndex]->addr);
T *batch_input_b_value = reinterpret_cast<T *>(inputs[kInputBIndex]->addr);
T *batch_output_value = reinterpret_cast<T *>(outputs[kOutputIndex]->addr);
@ -110,5 +116,20 @@ bool CholeskySolveCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs,
}
return true;
}
std::vector<std::pair<KernelAttr, CholeskySolveCpuKernelMod::CholeskySolveFunc>> CholeskySolveCpuKernelMod::func_list_ =
{{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&CholeskySolveCpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&CholeskySolveCpuKernelMod::LaunchKernel<double>}};
std::vector<KernelAttr> CholeskySolveCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, CholeskySolveFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, CholeskySolve, CholeskySolveCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -17,12 +17,12 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHOLESKY_CPU_SOLVE_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHOLESKY_CPU_SOLVE_KERNEL_H_
#include <vector>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
class CholeskySolveCpuKernelMod : public NativeCpuKernelMod {
public:
CholeskySolveCpuKernelMod() = default;
@ -30,12 +30,26 @@ class CholeskySolveCpuKernelMod : public NativeCpuKernelMod {
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
void InitLeftMatrixInfo(const std::vector<size_t> &shape, const bool is_rank_equal, size_t *row, size_t *col);
void InitRightMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col);
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
using CholeskySolveFunc =
std::function<bool(CholeskySolveCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, CholeskySolveFunc>> func_list_;
CholeskySolveFunc kernel_func_;
size_t outer_batch_{1};
size_t input_a_row_{1};
size_t input_a_col_{1};
@ -46,15 +60,6 @@ class CholeskySolveCpuKernelMod : public NativeCpuKernelMod {
TypeId dtype_{kNumberTypeFloat32};
bool lower_{false};
};
MS_REG_CPU_KERNEL_T(
CholeskySolve,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
CholeskySolveCpuKernelMod, float)
MS_REG_CPU_KERNEL_T(
CholeskySolve,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
CholeskySolveCpuKernelMod, double)
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHOLESKY_CPU_SOLVE_KERNEL_H_

View File

@ -15,7 +15,9 @@
*/
#include "plugin/device/cpu/kernel/eigen/eig_cpu_kernel.h"
#include <algorithm>
#include <type_traits>
#include <utility>
#include "plugin/device/cpu/kernel/eigen/eigen_common_utils.h"
#include "utils/ms_utils.h"
#include "Eigen/Eigenvalues"
@ -28,8 +30,7 @@ constexpr size_t kOutputsNumNV = 1;
constexpr size_t kOutputsNumV = 2;
} // namespace
template <typename T, typename C>
void EigCpuKernelMod<T, C>::InitMatrixInfo(const std::vector<size_t> &shape) {
void EigCpuKernelMod::InitMatrixInfo(const std::vector<size_t> &shape) {
if (shape.size() < kShape2dDims) {
MS_LOG_EXCEPTION << "For '" << kernel_name_ << "', the rank of parameter 'a' must be at least 2, but got "
<< shape.size() << " dimensions.";
@ -48,8 +49,7 @@ void EigCpuKernelMod<T, C>::InitMatrixInfo(const std::vector<size_t> &shape) {
batch_size_ /= (row_size_ * col_size_);
}
template <typename T, typename C>
void EigCpuKernelMod<T, C>::InitKernel(const CNodePtr &kernel_node) {
void EigCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
// If compute_v_ is true, then: w, v = Eig(a)
// If compute_v_ is false, then: w = Eig(a)
@ -63,10 +63,17 @@ void EigCpuKernelMod<T, C>::InitKernel(const CNodePtr &kernel_node) {
CHECK_KERNEL_OUTPUTS_NUM(output_num, expect_output_num, kernel_name_);
auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
InitMatrixInfo(input_shape);
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "Eig does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = func_list_[index].second;
}
template <typename T, typename C>
bool EigCpuKernelMod<T, C>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
bool EigCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
auto input_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto output_w_addr = reinterpret_cast<C *>(outputs[0]->addr);
@ -92,5 +99,46 @@ bool EigCpuKernelMod<T, C>::Launch(const std::vector<AddressPtr> &inputs, const
}
return true;
}
std::vector<std::pair<KernelAttr, EigCpuKernelMod::EigFunc>> EigCpuKernelMod::func_list_ = {
// If compute_v is false.
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeComplex64),
&EigCpuKernelMod::LaunchKernel<float, float_complex>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeComplex128),
&EigCpuKernelMod::LaunchKernel<double, double_complex>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
&EigCpuKernelMod::LaunchKernel<float_complex, float_complex>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
&EigCpuKernelMod::LaunchKernel<double_complex, double_complex>},
// If compute_v is true.
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
&EigCpuKernelMod::LaunchKernel<float, float_complex>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
&EigCpuKernelMod::LaunchKernel<double, double_complex>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
&EigCpuKernelMod::LaunchKernel<float_complex, float_complex>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
&EigCpuKernelMod::LaunchKernel<double_complex, double_complex>}};
std::vector<KernelAttr> EigCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, EigFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Eig, EigCpuKernelMod);
} // namespace kernel
} // namespace mindspore
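
One detail worth calling out in the Eig table above: a single support list carries both the compute_v=false rows (one output) and the compute_v=true rows (two outputs), and the match is decided by the node's actual output list. A rough sketch of such exact-list matching, with plain ints standing in for the real type ids; this is not the real MatchKernelAttr.

// Sketch: exact input/output type-list matching. A one-output Eig node
// selects a compute_v=false row; a two-output node selects a compute_v=true
// row, because the output lists differ in length.
#include <cstddef>
#include <utility>
#include <vector>

struct Attr {
  std::vector<int> inputs;   // stand-in for input TypeIds
  std::vector<int> outputs;  // stand-in for output TypeIds
};

std::pair<bool, size_t> MatchAttr(const Attr &node, const std::vector<Attr> &support) {
  for (size_t i = 0; i < support.size(); ++i) {
    if (support[i].inputs == node.inputs && support[i].outputs == node.outputs) {
      return {true, i};
    }
  }
  return {false, 0};
}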

View File

@ -19,8 +19,9 @@
#include <vector>
#include <complex>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
@ -28,60 +29,40 @@ namespace kernel {
using float_complex = std::complex<float>;
using double_complex = std::complex<double>;
/**
* this is for Generic matrix eigenvalues and eigenvectors
* @tparam T , input Type
* @tparam C , output Type, complex
*/
template <typename T, typename C>
class EigCpuKernelMod : public NativeCpuKernelMod {
public:
EigCpuKernelMod() = default;
~EigCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
void InitMatrixInfo(const std::vector<size_t> &shape);
/**
* Computes eigenvalues and eigenvectors of a generic matrix.
* @tparam T input type
* @tparam S output type (complex)
*/
template <typename T, typename S>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
using EigFunc = std::function<bool(EigCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, EigFunc>> func_list_;
EigFunc kernel_func_;
bool compute_v_{true};
size_t row_size_{1};
size_t col_size_{1};
size_t batch_size_{1};
};
// If compute_v is false.
MS_REG_CPU_KERNEL_T_S(Eig, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeComplex64),
EigCpuKernelMod, float, float_complex);
MS_REG_CPU_KERNEL_T_S(Eig, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeComplex128),
EigCpuKernelMod, double, double_complex);
MS_REG_CPU_KERNEL_T_S(Eig, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
EigCpuKernelMod, float_complex, float_complex);
MS_REG_CPU_KERNEL_T_S(Eig, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
EigCpuKernelMod, double_complex, double_complex);
// If compute_v is true.
MS_REG_CPU_KERNEL_T_S(
Eig,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
EigCpuKernelMod, float, float_complex);
MS_REG_CPU_KERNEL_T_S(Eig,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
EigCpuKernelMod, double, double_complex);
MS_REG_CPU_KERNEL_T_S(Eig,
KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
EigCpuKernelMod, float_complex, float_complex);
MS_REG_CPU_KERNEL_T_S(Eig,
KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
EigCpuKernelMod, double_complex, double_complex);
} // namespace kernel
} // namespace mindspore

View File

@ -15,6 +15,8 @@
*/
#include "plugin/device/cpu/kernel/eigen/eigh_cpu_kernel.h"
#include <algorithm>
#include <tuple>
#include <type_traits>
#include "plugin/device/cpu/kernel/eigen/eigen_common_utils.h"
#include "Eigen/Eigenvalues"
@ -27,8 +29,7 @@ constexpr size_t kInputsNum = 1;
constexpr size_t kOutputsNum = 2;
} // namespace
template <typename T>
void EighCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
void EighCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
compute_eigen_vectors_ = common::AnfAlgo::GetNodeAttr<bool>(kernel_node, C_EIEH_VECTOR);
@ -44,6 +45,15 @@ void EighCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
<< A_shape[kDim1] << "].";
}
m_ = A_shape[kDim0];
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "Eigh does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = std::get<1>(func_list_[index]);
const size_t kTwoIdx = 2;
init_io_func_ = std::get<kTwoIdx>(func_list_[index]);
}
template <typename T>
@ -70,14 +80,14 @@ void SolveComplexMatrix(const Map<MatrixSquare<T>> &A, Map<MatrixSquare<T>> *out
}
template <typename T>
void EighCpuKernelMod<T>::InitInputOutputSize(const CNodePtr &kernel_node) {
void EighCpuKernelMod::InitIOFunc(const CNodePtr &kernel_node) {
NativeCpuKernelMod::InitInputOutputSize(kernel_node);
(void)workspace_size_list_.emplace_back(m_ * m_ * sizeof(T));
}
template <typename T>
bool EighCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
bool EighCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kInputsNum, kernel_name_);
auto A_addr = reinterpret_cast<T *>(inputs[0]->addr);
// is the Matrix a symmetric matrix(true lower triangle, false upper triangle)
@ -103,5 +113,39 @@ bool EighCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs, const st
}
return true;
}
std::vector<std::tuple<KernelAttr, EighCpuKernelMod::EighFunc, EighCpuKernelMod::EighInitFunc>>
EighCpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&EighCpuKernelMod::LaunchKernel<float>, &EighCpuKernelMod::InitIOFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&EighCpuKernelMod::LaunchKernel<double>, &EighCpuKernelMod::InitIOFunc<double>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
&EighCpuKernelMod::LaunchKernel<float_complex>, &EighCpuKernelMod::InitIOFunc<float_complex>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
&EighCpuKernelMod::LaunchKernel<double_complex>, &EighCpuKernelMod::InitIOFunc<double_complex>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&EighCpuKernelMod::LaunchKernel<float>, &EighCpuKernelMod::InitIOFunc<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&EighCpuKernelMod::LaunchKernel<double>, &EighCpuKernelMod::InitIOFunc<double>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
&EighCpuKernelMod::LaunchKernel<float_complex>, &EighCpuKernelMod::InitIOFunc<float_complex>},
{KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
&EighCpuKernelMod::LaunchKernel<double_complex>, &EighCpuKernelMod::InitIOFunc<double_complex>}};
std::vector<KernelAttr> EighCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::tuple<KernelAttr, EighFunc, EighInitFunc> &item) { return std::get<0>(item); });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Eigh, EighCpuKernelMod);
} // namespace kernel
} // namespace mindspore
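Taken together, the file now follows a single pattern: one non-templated kernel class, a static table pairing each KernelAttr with its templated member functions, and a one-time bind at init. A minimal self-contained sketch of that pattern, with every name below (TypeId, Kernel, table_) an illustrative stand-in rather than the MindSpore API:

```cpp
#include <cstddef>
#include <cstdio>
#include <functional>
#include <tuple>
#include <vector>

enum TypeId { kFloat32, kFloat64 };

class Kernel {
 public:
  // Pick the typed implementation whose signature matches the requested dtype.
  bool Init(TypeId dtype) {
    for (const auto &entry : table_) {
      if (std::get<0>(entry) == dtype) {
        launch_ = std::get<1>(entry);
        return true;
      }
    }
    return false;  // mirrors the "does not support this kernel data type" branch
  }
  bool Launch(void *in, void *out, std::size_t n) { return launch_(this, in, out, n); }

 private:
  template <typename T>
  bool LaunchKernel(void *in, void *out, std::size_t n) {
    const T *src = static_cast<const T *>(in);
    T *dst = static_cast<T *>(out);
    for (std::size_t i = 0; i < n; ++i) dst[i] = src[i];  // placeholder compute
    return true;
  }
  using LaunchFunc = std::function<bool(Kernel *, void *, void *, std::size_t)>;
  // One entry per supported dtype, analogous to func_list_ above.
  inline static std::vector<std::tuple<TypeId, LaunchFunc>> table_ = {
      {kFloat32, &Kernel::LaunchKernel<float>}, {kFloat64, &Kernel::LaunchKernel<double>}};
  LaunchFunc launch_;
};

int main() {
  float src[3] = {1.f, 2.f, 3.f}, dst[3] = {};
  Kernel k;
  if (k.Init(kFloat32) && k.Launch(src, dst, 3)) std::printf("%g %g %g\n", dst[0], dst[1], dst[2]);
  return 0;
}
```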

View File

@ -19,67 +19,52 @@
#include <vector>
#include <complex>
#include <tuple>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
using float_complex = std::complex<float>;
using double_complex = std::complex<double>;
/**
* Symmetric matrix eigenvalues & eigenvectors; the lower/upper triangular storage can be decompressed into the full matrix.
* @tparam T , input Type
* @tparam C , output Type, complex
*/
template <typename T>
class EighCpuKernelMod : public NativeCpuKernelMod {
public:
EighCpuKernelMod() = default;
~EighCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
void InitInputOutputSize(const CNodePtr &kernel_node) override;
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, workspace, outputs);
}
void InitInputOutputSize(const CNodePtr &kernel_node) override { init_io_func_(this, kernel_node); }
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
/**
* Symmetric matrix eigenvalues & eigenvectors; the lower/upper triangular storage can be decompressed into the full matrix.
* @tparam T , input Type
*/
template <typename T>
void InitIOFunc(const CNodePtr &kernel_node);
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
using EighFunc =
std::function<bool(EighCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
using EighInitFunc = std::function<void(EighCpuKernelMod *, const CNodePtr &)>;
static std::vector<std::tuple<KernelAttr, EighFunc, EighInitFunc>> func_list_;
EighFunc kernel_func_;
EighInitFunc init_io_func_;
size_t m_{1};
bool compute_eigen_vectors_{false};
bool lower_{true};
TypeId dtype_{kNumberTypeFloat32};
};
MS_REG_CPU_KERNEL_T(Eigh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EighCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(Eigh, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
EighCpuKernelMod, double);
MS_REG_CPU_KERNEL_T(Eigh, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
EighCpuKernelMod, float_complex);
MS_REG_CPU_KERNEL_T(Eigh, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
EighCpuKernelMod, double_complex);
MS_REG_CPU_KERNEL_T(
Eigh,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EighCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(
Eigh,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
EighCpuKernelMod, double);
MS_REG_CPU_KERNEL_T(Eigh,
KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64)
.AddOutputAttr(kNumberTypeComplex64),
EighCpuKernelMod, float_complex);
MS_REG_CPU_KERNEL_T(Eigh,
KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128)
.AddOutputAttr(kNumberTypeComplex128),
EighCpuKernelMod, double_complex);
} // namespace kernel
} // namespace mindspore
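For the numerical side, the kernel defers to Eigen's self-adjoint solver. A standalone sketch of that usage on a small symmetric matrix; passing Eigen::EigenvaluesOnly instead of Eigen::ComputeEigenvectors is presumably what a false compute_eigen_vectors_ flag maps to:

```cpp
#include <iostream>
#include <Eigen/Eigenvalues>

int main() {
  Eigen::Matrix3d a;
  a << 2, 1, 0,
       1, 2, 1,
       0, 1, 2;  // symmetric, so one triangle alone carries all the information
  // SelfAdjointEigenSolver returns eigenvalues in increasing order and,
  // optionally, the eigenvectors.
  Eigen::SelfAdjointEigenSolver<Eigen::Matrix3d> solver(a, Eigen::ComputeEigenvectors);
  if (solver.info() != Eigen::Success) return 1;
  std::cout << "eigenvalues:\n" << solver.eigenvalues() << "\n";
  std::cout << "eigenvectors:\n" << solver.eigenvectors() << "\n";
  return 0;
}
```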

View File

@ -19,6 +19,7 @@
#include <algorithm>
#include <utility>
#include <unordered_map>
#include <tuple>
#include "utils/ms_utils.h"
namespace mindspore {
@ -35,8 +36,7 @@ constexpr size_t kColIndex = 1;
constexpr int kZeroThreshold = INT32_MIN;
} // namespace
template <typename T>
void LUCpuKernelMod<T>::InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
void LUCpuKernelMod::InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
constexpr size_t lu_min_dim = 1;
if (shape.size() <= lu_min_dim) {
MS_LOG_EXCEPTION << kernel_name_ << "shape is " << shape.size() << " which is invalid.";
@ -50,8 +50,7 @@ void LUCpuKernelMod<T>::InitMatrixInfo(const std::vector<size_t> &shape, size_t
}
}
template <typename T>
void LUCpuKernelMod<T>::InitPivotVecInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
void LUCpuKernelMod::InitPivotVecInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
constexpr size_t pivot_min_dim = 1;
if (shape.size() < pivot_min_dim) {
MS_LOG_EXCEPTION << kernel_name_ << "pivots shape is " << shape.size() << " which is invalid.";
@ -64,8 +63,7 @@ void LUCpuKernelMod<T>::InitPivotVecInfo(const std::vector<size_t> &shape, size_
}
}
template <typename T>
void LUCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
void LUCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
@ -81,10 +79,19 @@ void LUCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) {
InitMatrixInfo(permutation_shape, &permutation_row_, &permutation_col_);
auto pivots_shape = common::AnfAlgo::GetOutputInferShape(kernel_node, kPivotsIndex);
InitPivotVecInfo(pivots_shape, &pivots_row_, &pivots_col_);
auto kernel_attr = GetKernelAttrFromNode(kernel_node);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(EXCEPTION) << "LU does not support this kernel data type: " << kernel_attr;
}
kernel_func_ = std::get<1>(func_list_[index]);
const size_t kTwoIdx = 2;
init_io_func_ = std::get<kTwoIdx>(func_list_[index]);
}
template <typename T>
void LUCpuKernelMod<T>::InitInputOutputSize(const CNodePtr &kernel_node) {
void LUCpuKernelMod::InitIOSize(const CNodePtr &kernel_node) {
NativeCpuKernelMod::InitInputOutputSize(kernel_node);
size_t lu_size = lu_col_ * sizeof(T);
(void)workspace_size_list_.emplace_back(lu_size);
@ -92,14 +99,14 @@ void LUCpuKernelMod<T>::InitInputOutputSize(const CNodePtr &kernel_node) {
}
template <typename T>
T LUCpuKernelMod<T>::GetPermutatedValue(const T *lu_value, const std::vector<int> &per_value, size_t i, size_t j) {
T LUCpuKernelMod::GetPermutatedValue(const T *lu_value, const std::vector<int> &per_value, size_t i, size_t j) {
const T *pered_lu_value = lu_value + per_value[i] * lu_col_ + j;
return *pered_lu_value;
}
template <typename T>
bool LUCpuKernelMod<T>::UpdateMajorPermutation(T *lu_value, std::vector<int> *per_value, int *pivots, size_t k,
size_t rows) {
bool LUCpuKernelMod::UpdateMajorPermutation(T *lu_value, std::vector<int> *per_value, int *pivots, size_t k,
size_t rows) {
T max_major_value = static_cast<T>(kZeroThreshold);
size_t max_major_index = 0;
for (size_t i = k; i < rows; ++i) {
@ -118,16 +125,16 @@ bool LUCpuKernelMod<T>::UpdateMajorPermutation(T *lu_value, std::vector<int> *pe
}
template <typename T>
void LUCpuKernelMod<T>::SetPermutatedValue(T *lu_value, const std::vector<int> &per_value, size_t i, size_t j,
const T &value) {
void LUCpuKernelMod::SetPermutatedValue(T *lu_value, const std::vector<int> &per_value, size_t i, size_t j,
const T &value) {
T *per_lu_value = lu_value + per_value[i] * lu_col_ + j;
*per_lu_value = value;
}
template <typename T>
bool LUCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs) {
bool LUCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs) {
// input matrix of shape (m, n); factorization PA = LU
T *batch_a_value = reinterpret_cast<T *>(inputs[kLUaIndex]->addr);
T *batch_lu_value = reinterpret_cast<T *>(outputs[kLuIndex]->addr);
@ -233,5 +240,28 @@ bool LUCpuKernelMod<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
return true;
}
std::vector<std::tuple<KernelAttr, LUCpuKernelMod::LUFunc, LUCpuKernelMod::InitFunc>> LUCpuKernelMod::func_list_ = {
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
&LUCpuKernelMod::LaunchKernel<float>, &LUCpuKernelMod::InitIOSize<float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
&LUCpuKernelMod::LaunchKernel<double>, &LUCpuKernelMod::InitIOSize<double>}};
std::vector<KernelAttr> LUCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::tuple<KernelAttr, LUFunc, InitFunc> &tuple_item) { return std::get<0>(tuple_item); });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, LU, LUCpuKernelMod);
} // namespace kernel
} // namespace mindspore
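Note that the launch path above never swaps rows physically: UpdateMajorPermutation records the pivot order, and GetPermutatedValue/SetPermutatedValue route every element access through that permutation. A compact sketch of the same scheme on a plain row-major buffer, illustrative rather than the kernel's exact numerics:

```cpp
#include <cmath>
#include <cstdio>
#include <numeric>
#include <utility>
#include <vector>

// Factor a (in place, n x n, row-major) so that P*A = L*U; perm encodes P.
bool LuFactor(std::vector<double> *a, std::vector<int> *perm, int n) {
  std::iota(perm->begin(), perm->end(), 0);  // start from the identity permutation
  // All row access goes through perm, so pivoting is just a perm swap.
  auto at = [&](int i, int j) -> double & { return (*a)[(*perm)[i] * n + j]; };
  for (int k = 0; k < n; ++k) {
    // Partial pivoting: pick the row with the largest |a(i,k)| at or below the diagonal.
    int p = k;
    for (int i = k + 1; i < n; ++i)
      if (std::fabs(at(i, k)) > std::fabs(at(p, k))) p = i;
    std::swap((*perm)[k], (*perm)[p]);
    if (at(k, k) == 0.0) return false;  // singular matrix
    // Eliminate below the pivot, storing the multipliers in L's slot.
    for (int i = k + 1; i < n; ++i) {
      at(i, k) /= at(k, k);
      for (int j = k + 1; j < n; ++j) at(i, j) -= at(i, k) * at(k, j);
    }
  }
  return true;
}

int main() {
  std::vector<double> a = {0, 2, 1, 1};  // [[0, 2], [1, 1]] forces a pivot swap
  std::vector<int> perm(2);
  if (LuFactor(&a, &perm, 2)) std::printf("perm: %d %d\n", perm[0], perm[1]);
  return 0;
}
```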

View File

@ -17,28 +17,49 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_CPU_KERNEL_H_
#include <tuple>
#include <vector>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
class LUCpuKernelMod : public NativeCpuKernelMod {
public:
LUCpuKernelMod() = default;
~LUCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
const std::vector<AddressPtr> &outputs) override {
return kernel_func_(this, inputs, workspace, outputs);
}
protected:
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
T GetPermutatedValue(const T *lu_value, const std::vector<int> &per_value, size_t i, size_t j);
template <typename T>
bool UpdateMajorPermutation(T *lu_value, std::vector<int> *per_value, int *pivots, size_t k, size_t rows);
template <typename T>
void SetPermutatedValue(T *lu_value, const std::vector<int> &per_value, size_t i, size_t j, const T &value);
template <typename T>
void InitIOSize(const CNodePtr &kernel_node);
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs);
using LUFunc = std::function<bool(LUCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
using InitFunc = std::function<void(LUCpuKernelMod *, const CNodePtr &)>;
static std::vector<std::tuple<KernelAttr, LUFunc, InitFunc>> func_list_;
LUFunc kernel_func_;
InitFunc init_io_func_;
void InitInputOutputSize(const CNodePtr &kernel_node) override { init_io_func_(this, kernel_node); }
void InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col);
void InitPivotVecInfo(const std::vector<size_t> &shape, size_t *row, size_t *col);
void InitInputOutputSize(const CNodePtr &kernel_node) override;
T GetPermutatedValue(const T *lu_value, const std::vector<int> &per_value, size_t i, size_t j);
bool UpdateMajorPermutation(T *lu_value, std::vector<int> *per_value, int *pivots, size_t k, size_t rows);
void SetPermutatedValue(T *lu_value, const std::vector<int> &per_value, size_t i, size_t j, const T &value);
size_t batch_size_{1};
size_t a_row_{1};
size_t a_col_{1};
@ -51,21 +72,6 @@ class LUCpuKernelMod : public NativeCpuKernelMod {
TypeId dtype_{kNumberTypeFloat32};
int *batch_pivots_{nullptr};
};
MS_REG_CPU_KERNEL_T(LU,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
LUCpuKernelMod, float);
MS_REG_CPU_KERNEL_T(LU,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
LUCpuKernelMod, double);
} // namespace kernel
} // namespace mindspore
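MS_KERNEL_FACTORY_REG above is the one-line replacement for the per-dtype MS_REG_CPU_KERNEL_T blocks deleted from this header. A minimal sketch of the self-registering factory idea behind such a macro; the real plugin/factory/ms_factory.h will differ in detail:

```cpp
#include <functional>
#include <map>
#include <memory>
#include <string>

template <typename Base>
class Factory {
 public:
  static Factory &Instance() {
    static Factory instance;  // one registry per base type
    return instance;
  }
  void Register(const std::string &name, std::function<std::shared_ptr<Base>()> creator) {
    creators_[name] = std::move(creator);
  }
  std::shared_ptr<Base> Create(const std::string &name) const {
    auto it = creators_.find(name);
    return it == creators_.end() ? nullptr : it->second();
  }

 private:
  std::map<std::string, std::function<std::shared_ptr<Base>()>> creators_;
};

// A static Registrar object is what a registration macro would expand to:
// its constructor runs before main() and adds the creator to the map.
template <typename Base, typename Derived>
struct Registrar {
  explicit Registrar(const std::string &name) {
    Factory<Base>::Instance().Register(name, [] { return std::make_shared<Derived>(); });
  }
};

struct NativeKernel {
  virtual ~NativeKernel() = default;
};
struct LuKernel : NativeKernel {};
static Registrar<NativeKernel, LuKernel> g_lu_reg("LU");

int main() { return Factory<NativeKernel>::Instance().Create("LU") ? 0 : 1; }
```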

Some files were not shown because too many files have changed in this diff