forked from mindspore-Ecosystem/mindspore
cpu kernel register type
This commit is contained in:
parent
e6273ce33b
commit
90fd53c1d8
|
@ -0,0 +1,143 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "device/cpu/kernel_select_cpu.h"
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
|
||||
#include "kernel/cpu/cpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace cpu {
|
||||
using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
|
||||
using mindspore::kernel::KernelBuildInfo;
|
||||
namespace {
|
||||
bool IsInputNotCNode(const CNodePtr &kernel_node, size_t input_index) {
|
||||
auto input_node = AnfAlgo::VisitKernel(kernel_node->input(input_index + 1), 0).first;
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
if (input_node->isa<Parameter>() || input_node->isa<ValueNode>()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void UpdatePrevNotCNodeFormatDtype(const KernelAttr &kernel_attr, const std::vector<size_t> &input_not_cnode_indexes,
|
||||
const CNodePtr kernel_node) {
|
||||
for (auto &input_index : input_not_cnode_indexes) {
|
||||
auto input_node = AnfAlgo::VisitKernel(kernel_node->input(input_index + 1), 0).first;
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
std::vector<TypeId> output_types;
|
||||
output_types.emplace_back(kernel_attr.GetInputAttr(input_index).first);
|
||||
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
MS_EXCEPTION_IF_NULL(builder);
|
||||
builder->SetOutputsFormat({kOpFormat_DEFAULT});
|
||||
builder->SetOutputsDeviceType(output_types);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_node.get());
|
||||
}
|
||||
}
|
||||
|
||||
void GetInputFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::string> *input_formats,
|
||||
std::vector<TypeId> *input_types, std::vector<size_t> *input_no_cnode_indexes) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
TypeId dtype = kTypeUnknown;
|
||||
if (IsInputNotCNode(kernel_node, input_index)) {
|
||||
input_no_cnode_indexes->emplace_back(input_index);
|
||||
} else {
|
||||
dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index);
|
||||
}
|
||||
input_formats->emplace_back(kOpFormat_DEFAULT);
|
||||
input_types->emplace_back(dtype);
|
||||
}
|
||||
}
|
||||
|
||||
void GetOutputFormatsAndDtypes(const CNodePtr &kernel_node, const KernelAttr &kernel_attr,
|
||||
std::vector<std::string> *output_formats, std::vector<TypeId> *output_types) {
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (kernel_attr.GetOutputSize() != output_num) {
|
||||
MS_LOG(EXCEPTION) << "Output num is not equal!";
|
||||
}
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
output_formats->emplace_back(kernel_attr.GetOutputAttr(output_index).second);
|
||||
auto dtype = kernel_attr.GetOutputAttr(output_index).first;
|
||||
output_types->emplace_back(dtype);
|
||||
}
|
||||
}
|
||||
|
||||
bool IsInputFormatDtypeMatched(const KernelAttr &kernel_attr, const std::vector<std::string> &input_formats,
|
||||
const std::vector<TypeId> &input_types,
|
||||
const std::vector<size_t> &input_not_cnode_indexes) {
|
||||
if (kernel_attr.GetInputSize() != input_types.size()) {
|
||||
MS_LOG(ERROR) << "Output num is not equal!";
|
||||
return false;
|
||||
}
|
||||
auto input_num = input_types.size();
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
bool is_not_cnode_idx = std::any_of(input_not_cnode_indexes.begin(), input_not_cnode_indexes.end(),
|
||||
[i](size_t index) { return index == i; });
|
||||
if (is_not_cnode_idx) {
|
||||
continue;
|
||||
}
|
||||
if (kernel_attr.GetInputAttr(i).first != input_types[i]) {
|
||||
MS_LOG(ERROR) << "reg dtype=" << kernel_attr.GetInputAttr(i).first << ", input dtype=" << input_types[i];
|
||||
return false;
|
||||
}
|
||||
if (kernel_attr.GetInputAttr(i).second != input_formats[i]) {
|
||||
MS_LOG(ERROR) << "reg format=" << kernel_attr.GetInputAttr(i).second << ", input format=" << input_formats[i];
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void SetKernelInfo(const CNodePtr &kernel_node) {
|
||||
std::vector<std::string> input_formats;
|
||||
std::vector<TypeId> input_types;
|
||||
std::vector<size_t> input_not_cnode_indexes;
|
||||
std::vector<std::string> output_formats;
|
||||
std::vector<TypeId> output_types;
|
||||
|
||||
GetInputFormatsAndDtypes(kernel_node, &input_formats, &input_types, &input_not_cnode_indexes);
|
||||
|
||||
auto kernel_attrs =
|
||||
kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node));
|
||||
|
||||
for (auto &kernel_attr : kernel_attrs) {
|
||||
if (IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) {
|
||||
GetOutputFormatsAndDtypes(kernel_node, kernel_attr, &output_formats, &output_types);
|
||||
UpdatePrevNotCNodeFormatDtype(kernel_attr, input_not_cnode_indexes, kernel_node);
|
||||
for (auto &input_index : input_not_cnode_indexes) {
|
||||
input_types[input_index] = kernel_attr.GetInputAttr(input_index).first;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
MS_EXCEPTION_IF_NULL(builder);
|
||||
builder->SetInputsFormat(input_formats);
|
||||
builder->SetInputsDeviceType(input_types);
|
||||
builder->SetOutputsFormat(output_formats);
|
||||
builder->SetOutputsDeviceType(output_types);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get());
|
||||
}
|
||||
} // namespace cpu
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_DEVICE_CPU_KERNEL_SELECT_CPU_H_
|
||||
#define MINDSPORE_CCSRC_DEVICE_CPU_KERNEL_SELECT_CPU_H_
|
||||
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "ir/dtype/type.h"
|
||||
#include "utils/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace cpu {
|
||||
void SetKernelInfo(const CNodePtr &apply_kernel_ptr);
|
||||
|
||||
class KernelAttr {
|
||||
public:
|
||||
using DataType = std::pair<TypeId, std::string>;
|
||||
KernelAttr() = default;
|
||||
~KernelAttr() = default;
|
||||
|
||||
KernelAttr &AddInputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) {
|
||||
input_type_.emplace_back(ms_type, format);
|
||||
return *this;
|
||||
}
|
||||
|
||||
KernelAttr &AddOutputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) {
|
||||
output_type_.emplace_back(ms_type, format);
|
||||
return *this;
|
||||
}
|
||||
|
||||
const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; }
|
||||
const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; }
|
||||
|
||||
size_t GetInputSize() const { return input_type_.size(); }
|
||||
size_t GetOutputSize() const { return output_type_.size(); }
|
||||
|
||||
private:
|
||||
std::vector<DataType> input_type_;
|
||||
std::vector<DataType> output_type_;
|
||||
};
|
||||
} // namespace cpu
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_DEVICE_CPU_KERNEL_SELECT_CPU_H_
|
|
@ -33,7 +33,15 @@ class ApplyMomentumCPUKernel : public MKLCPUKernel {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(ApplyMomentum, ApplyMomentumCPUKernel);
|
||||
MS_REG_CPU_KERNEL(ApplyMomentum,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
ApplyMomentumCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -37,7 +37,8 @@ class ArgmaxCPUKernel : public CPUKernel {
|
|||
size_t batch_size_{0};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(Argmax, ArgmaxCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ArgmaxCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -37,7 +37,10 @@ class BiasAddCPUKernel : public CPUKernel {
|
|||
std::vector<size_t> input_shape_;
|
||||
std::vector<size_t> bias_shape_;
|
||||
};
|
||||
MS_REG_CPU_KERNEL(BiasAdd, BiasAddCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
BiasAdd,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
BiasAddCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIAS_ADD_CPU_KERNEL_H_
|
||||
|
|
|
@ -36,7 +36,8 @@ class BiasAddGradCPUKernel : public CPUKernel {
|
|||
private:
|
||||
std::vector<size_t> input_shape_;
|
||||
};
|
||||
MS_REG_CPU_KERNEL(BiasAddGrad, BiasAddGradCPUKernel);
|
||||
MS_REG_CPU_KERNEL(BiasAddGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
BiasAddGradCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIASADDGRADCPUKERNEL_H_
|
||||
|
|
|
@ -20,29 +20,79 @@
|
|||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "device/kernel_info.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
CPUKernelFactory &CPUKernelFactory::Get() {
|
||||
CPUKernelFactory &CPUKernelFactory::GetInstance() {
|
||||
static CPUKernelFactory instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void CPUKernelFactory::Register(const std::string &kernel_name, CPUKernelCreator &&kernel_creator) {
|
||||
if (kernel_creators_.find(kernel_name) == kernel_creators_.end()) {
|
||||
(void)kernel_creators_.emplace(kernel_name, kernel_creator);
|
||||
void CPUKernelFactory::Register(const std::string &kernel_name, const KernelAttr &kernel_attr,
|
||||
CPUKernelCreator &&kernel_creator) {
|
||||
(void)name_to_attr_creator_[kernel_name].emplace_back(kernel_attr, kernel_creator);
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
MS_LOG(DEBUG) << "CPUKernelFactory register operator: " << kernel_name;
|
||||
MS_LOG(DEBUG) << "CPUKernelFactory register operator: " << kernel_name;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<CPUKernel> CPUKernelFactory::Create(const std::string &kernel_name) {
|
||||
auto iter = kernel_creators_.find(kernel_name);
|
||||
if (iter != kernel_creators_.end()) {
|
||||
MS_EXCEPTION_IF_NULL(iter->second);
|
||||
return (iter->second)();
|
||||
std::shared_ptr<CPUKernel> CPUKernelFactory::Create(const std::string &kernel_name, const CNodePtr &apply_kernel) {
|
||||
auto kernel_info = apply_kernel->kernel_info();
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
const KernelBuildInfo *kernel_build_Info = kernel_info->select_kernel_build_info();
|
||||
MS_EXCEPTION_IF_NULL(kernel_build_Info);
|
||||
std::pair<bool, size_t> ret_pair = CPUKernelAttrCheck(kernel_name, kernel_build_Info);
|
||||
if (ret_pair.first) {
|
||||
return (name_to_attr_creator_.find(kernel_name)->second)[ret_pair.second].second();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::pair<bool, size_t> CPUKernelFactory::CPUKernelAttrCheck(const std::string &kernel_name,
|
||||
const KernelBuildInfo *kernel_info) {
|
||||
auto iter = name_to_attr_creator_.find(kernel_name);
|
||||
if (iter == name_to_attr_creator_.end()) {
|
||||
MS_LOG(INFO) << "Not registered CPU kernel: op[" << kernel_name << "]!";
|
||||
return std::make_pair(false, 0);
|
||||
}
|
||||
auto creators = iter->second;
|
||||
for (size_t index = 0; index < creators.size(); ++index) {
|
||||
auto attr_creator = creators[index];
|
||||
for (size_t i = 0; i < kernel_info->GetInputNum(); ++i) {
|
||||
if (kernel_info->GetInputDeviceType(i) != attr_creator.first.GetInputAttr(i).first) {
|
||||
MS_LOG(WARNING) << "cpu kernel attr check failed. input index: " << i << ".";
|
||||
MS_LOG(WARNING) << "kernel info type:" << kernel_info->GetInputDeviceType(i) << ", "
|
||||
<< "register type:" << attr_creator.first.GetInputAttr(i).first;
|
||||
return std::make_pair(false, 0);
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < kernel_info->GetOutputNum(); ++i) {
|
||||
if (kernel_info->GetOutputDeviceType(i) != attr_creator.first.GetOutputAttr(i).first) {
|
||||
MS_LOG(WARNING) << "cpu kernel attr check failed. output index: " << i << ".";
|
||||
MS_LOG(WARNING) << "kernel info type:" << kernel_info->GetOutputDeviceType(i) << ", "
|
||||
<< "register type:" << attr_creator.first.GetOutputAttr(i).first;
|
||||
return std::make_pair(false, 0);
|
||||
}
|
||||
}
|
||||
return std::make_pair(true, index);
|
||||
}
|
||||
return std::make_pair(false, 0);
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> CPUKernelFactory::GetSupportedKernelAttrList(const std::string &kernel_name) {
|
||||
std::vector<KernelAttr> result;
|
||||
auto iter = name_to_attr_creator_.find(kernel_name);
|
||||
if (iter == name_to_attr_creator_.end()) {
|
||||
MS_LOG(WARNING) << "Not registered CPU kernel: op[" << kernel_name << "]!";
|
||||
return result;
|
||||
}
|
||||
auto creators = iter->second;
|
||||
for (size_t index = 0; index < creators.size(); ++index) {
|
||||
auto attr_creator = creators[index];
|
||||
result.push_back(attr_creator.first);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,35 +21,54 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "common/utils.h"
|
||||
#include "kernel/cpu/cpu_kernel.h"
|
||||
#include "device/cpu/kernel_select_cpu.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
using mindspore::device::cpu::KernelAttr;
|
||||
using CPUKernelCreator = std::function<std::shared_ptr<CPUKernel>()>;
|
||||
class CPUKernelFactory {
|
||||
public:
|
||||
static CPUKernelFactory &Get();
|
||||
void Register(const std::string &kernel_name, CPUKernelCreator &&kernel_creator);
|
||||
static CPUKernelFactory &GetInstance();
|
||||
void Register(const std::string &kernel_name, const KernelAttr &kernel_attr, CPUKernelCreator &&kernel_creator);
|
||||
std::shared_ptr<CPUKernel> Create(const std::string &kernel_name);
|
||||
std::shared_ptr<CPUKernel> Create(const std::string &kernel_name, const CNodePtr &apply_kernel);
|
||||
std::vector<KernelAttr> GetSupportedKernelAttrList(const std::string &kernel_name);
|
||||
|
||||
private:
|
||||
CPUKernelFactory() = default;
|
||||
~CPUKernelFactory() = default;
|
||||
DISABLE_COPY_AND_ASSIGN(CPUKernelFactory)
|
||||
std::map<std::string, CPUKernelCreator> kernel_creators_;
|
||||
std::pair<bool, size_t> CPUKernelAttrCheck(const std::string &kernel_name, const KernelBuildInfo *kernel_info);
|
||||
std::map<std::string, std::vector<std::pair<KernelAttr, CPUKernelCreator>>> name_to_attr_creator_;
|
||||
};
|
||||
|
||||
class CPUKernelRegistrar {
|
||||
public:
|
||||
CPUKernelRegistrar(const std::string &kernel_name, CPUKernelCreator &&kernel_creator) {
|
||||
CPUKernelFactory::Get().Register(kernel_name, std::move(kernel_creator));
|
||||
CPUKernelRegistrar(const std::string &kernel_name, const KernelAttr &kernel_attr, CPUKernelCreator &&kernel_creator) {
|
||||
CPUKernelFactory::GetInstance().Register(kernel_name, kernel_attr, std::move(kernel_creator));
|
||||
}
|
||||
~CPUKernelRegistrar() = default;
|
||||
};
|
||||
|
||||
#define MS_REG_CPU_KERNEL(KERNEL_NAME, KERNEL_CLASS) \
|
||||
static const CPUKernelRegistrar g_cpu_kernel_##KERNEL_NAME##_reg(#KERNEL_NAME, \
|
||||
[]() { return std::make_shared<KERNEL_CLASS>(); });
|
||||
// Registers OPCLASS as the CPU kernel for op OPNAME with signature ATTR by defining a
// file-static CPUKernelRegistrar. The static_assert rejects classes that do not derive from
// CPUKernel and names the offending class in the diagnostic.
// The registrar variable name is derived from OPNAME only, so one translation unit can hold at
// most one MS_REG_CPU_KERNEL per op name; use the _T / _T_S variants for templated kernels.
#define MS_REG_CPU_KERNEL(OPNAME, ATTR, OPCLASS)                                                                   \
  static_assert(std::is_base_of<CPUKernel, OPCLASS>::value, #OPCLASS " must be derived from CPUKernel");          \
  static const CPUKernelRegistrar g_cpu_kernel_##OPNAME##_reg(#OPNAME, ATTR,                                       \
                                                              []() { return std::make_shared<OPCLASS>(); });

// Same as MS_REG_CPU_KERNEL for a kernel class template instantiated as OPCLASS<T>; T becomes
// part of the registrar variable name.
#define MS_REG_CPU_KERNEL_T(OPNAME, ATTR, OPCLASS, T)                                                              \
  static_assert(std::is_base_of<CPUKernel, OPCLASS<T>>::value,                                                     \
                #OPCLASS "<" #T "> must be derived from CPUKernel");                                               \
  static const CPUKernelRegistrar g_cpu_kernel_##OPNAME##_##T##_reg(#OPNAME, ATTR,                                 \
                                                                    []() { return std::make_shared<OPCLASS<T>>(); });

// Same as MS_REG_CPU_KERNEL for a kernel class template instantiated as OPCLASS<T, S>; T and S
// become part of the registrar variable name.
#define MS_REG_CPU_KERNEL_T_S(OPNAME, ATTR, OPCLASS, T, S)                                                         \
  static_assert(std::is_base_of<CPUKernel, OPCLASS<T, S>>::value,                                                  \
                #OPCLASS "<" #T ", " #S "> must be derived from CPUKernel");                                       \
  static const CPUKernelRegistrar g_cpu_kernel_##OPNAME##_##T##_##S##_reg(                                         \
    #OPNAME, ATTR, []() { return std::make_shared<OPCLASS<T, S>>(); });
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -33,7 +33,10 @@ class EqualCountCPUKernel : public CPUKernel {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(EqualCount, EqualCountCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
EqualCount,
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
EqualCountCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -33,7 +33,10 @@ class Conv2dCPUKernel : public MKLCPUKernel {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(Conv2D, Conv2dCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
Conv2D,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
Conv2dCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -33,7 +33,10 @@ class Conv2dGradFilterCPUKernel : public MKLCPUKernel {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(Conv2DBackpropFilter, Conv2dGradFilterCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
Conv2DBackpropFilter,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
Conv2dGradFilterCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -33,7 +33,10 @@ class Conv2dGradInputCPUKernel : public MKLCPUKernel {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(Conv2DBackpropInput, Conv2dGradInputCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
Conv2DBackpropInput,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
Conv2dGradInputCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -40,7 +40,10 @@ class MatMulCPUKernel : public MKLCPUKernel {
|
|||
dnnl_dim_t dim_k_{0};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(MatMul, MatMulCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
MatMul,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
MatMulCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -33,7 +33,9 @@ class MulCPUKernel : public MKLCPUKernel {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(Mul, MulCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
Mul, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
MulCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -33,7 +33,8 @@ class PoolingCPUKernel : public MKLCPUKernel {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(MaxPool, PoolingCPUKernel);
|
||||
MS_REG_CPU_KERNEL(MaxPool, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
PoolingCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -43,7 +43,13 @@ class PoolingGradCPUKernel : public MKLCPUKernel {
|
|||
std::vector<size_t> dst_shape_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(MaxPoolGrad, PoolingGradCPUKernel);
|
||||
MS_REG_CPU_KERNEL(MaxPoolGrad,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
PoolingGradCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ class ReluCPUKernel : public MKLCPUKernel {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(ReLU, ReluCPUKernel);
|
||||
MS_REG_CPU_KERNEL(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), ReluCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -33,7 +33,10 @@ class ReluGradCPUKernel : public MKLCPUKernel {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(ReluGrad, ReluGradCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
ReluGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ReluGradCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -33,7 +33,8 @@ class SoftmaxCPUKernel : public MKLCPUKernel {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(Softmax, SoftmaxCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Softmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SoftmaxCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -43,7 +43,10 @@ class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public MKLCPUKernel {
|
|||
size_t batch_size_{0};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(SparseSoftmaxCrossEntropyWithLogits, SparseSoftmaxCrossEntropyWithLogitsCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
SparseSoftmaxCrossEntropyWithLogits,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SparseSoftmaxCrossEntropyWithLogitsCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -38,7 +38,13 @@ class OneHotCPUKernel : public CPUKernel {
|
|||
size_t axis_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(OneHot, OneHotCPUKernel);
|
||||
MS_REG_CPU_KERNEL(OneHot,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
OneHotCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -33,9 +33,12 @@ class ReshapeCPUKernel : public CPUKernel {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(Reshape, ReshapeCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Flatten, ReshapeCPUKernel);
|
||||
MS_REG_CPU_KERNEL(ExpandDims, ReshapeCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ReshapeCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ReshapeCPUKernel);
|
||||
MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ReshapeCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "device/kernel_runtime.h"
|
||||
#include "predict/predict.h"
|
||||
#include "kernel/cpu/cpu_kernel_factory.h"
|
||||
#include "device/cpu/kernel_select_cpu.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace session {
|
||||
|
@ -63,43 +64,7 @@ void CPUSession::SetKernelInfo(const KernelGraph *kernel_graph) {
|
|||
auto &kernel_nodes = kernel_graph->execution_order();
|
||||
for (const auto &kernel_node : kernel_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
auto input_kernel_node = kernel_node->input(input_index + 1);
|
||||
MS_EXCEPTION_IF_NULL(input_kernel_node);
|
||||
if (!input_kernel_node->isa<Parameter>()) {
|
||||
continue;
|
||||
}
|
||||
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
MS_EXCEPTION_IF_NULL(builder);
|
||||
std::vector<std::string> output_formats = {kOpFormat_DEFAULT};
|
||||
builder->SetOutputsFormat(output_formats);
|
||||
std::vector<TypeId> output_types{kNumberTypeFloat32};
|
||||
builder->SetOutputsDeviceType(output_types);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
|
||||
}
|
||||
|
||||
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
MS_EXCEPTION_IF_NULL(builder);
|
||||
std::vector<std::string> input_formats;
|
||||
std::vector<TypeId> input_types;
|
||||
std::vector<std::string> output_formats;
|
||||
std::vector<TypeId> output_types;
|
||||
|
||||
for (size_t input_index = 0; input_index < input_num; ++input_index) {
|
||||
input_formats.emplace_back(kOpFormat_DEFAULT);
|
||||
input_types.emplace_back(kNumberTypeFloat32);
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
||||
output_formats.emplace_back(kOpFormat_DEFAULT);
|
||||
output_types.emplace_back(kNumberTypeFloat32);
|
||||
}
|
||||
builder->SetInputsFormat(input_formats);
|
||||
builder->SetInputsDeviceType(input_types);
|
||||
builder->SetOutputsFormat(output_formats);
|
||||
builder->SetOutputsDeviceType(output_types);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get());
|
||||
device::cpu::SetKernelInfo(kernel_node);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -110,7 +75,8 @@ void CPUSession::BuildKernel(const KernelGraph *kernel_graph) {
|
|||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||
MS_LOG(INFO) << "Cpu building operator[" << kernel_name << "].";
|
||||
std::shared_ptr<kernel::CPUKernel> cpu_kernel = kernel::CPUKernelFactory::Get().Create(kernel_name);
|
||||
std::shared_ptr<kernel::CPUKernel> cpu_kernel =
|
||||
kernel::CPUKernelFactory::GetInstance().Create(kernel_name, kernel_node);
|
||||
if (cpu_kernel == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Operator[" << kernel_name << "] is not support.";
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue