cpu kernel register type

This commit is contained in:
sunsuodong 2020-04-29 21:32:59 +08:00
parent e6273ce33b
commit 90fd53c1d8
23 changed files with 371 additions and 77 deletions

View File

@ -0,0 +1,143 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "device/cpu/kernel_select_cpu.h"
#include <string>
#include <memory>
#include <algorithm>
#include "kernel/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace device {
namespace cpu {
using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm;
using mindspore::kernel::KernelBuildInfo;
namespace {
bool IsInputNotCNode(const CNodePtr &kernel_node, size_t input_index) {
auto input_node = AnfAlgo::VisitKernel(kernel_node->input(input_index + 1), 0).first;
MS_EXCEPTION_IF_NULL(input_node);
if (input_node->isa<Parameter>() || input_node->isa<ValueNode>()) {
return true;
}
return false;
}
void UpdatePrevNotCNodeFormatDtype(const KernelAttr &kernel_attr, const std::vector<size_t> &input_not_cnode_indexes,
const CNodePtr kernel_node) {
for (auto &input_index : input_not_cnode_indexes) {
auto input_node = AnfAlgo::VisitKernel(kernel_node->input(input_index + 1), 0).first;
MS_EXCEPTION_IF_NULL(input_node);
std::vector<TypeId> output_types;
output_types.emplace_back(kernel_attr.GetInputAttr(input_index).first);
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
MS_EXCEPTION_IF_NULL(builder);
builder->SetOutputsFormat({kOpFormat_DEFAULT});
builder->SetOutputsDeviceType(output_types);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_node.get());
}
}
void GetInputFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::string> *input_formats,
std::vector<TypeId> *input_types, std::vector<size_t> *input_no_cnode_indexes) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
TypeId dtype = kTypeUnknown;
if (IsInputNotCNode(kernel_node, input_index)) {
input_no_cnode_indexes->emplace_back(input_index);
} else {
dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index);
}
input_formats->emplace_back(kOpFormat_DEFAULT);
input_types->emplace_back(dtype);
}
}
void GetOutputFormatsAndDtypes(const CNodePtr &kernel_node, const KernelAttr &kernel_attr,
std::vector<std::string> *output_formats, std::vector<TypeId> *output_types) {
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (kernel_attr.GetOutputSize() != output_num) {
MS_LOG(EXCEPTION) << "Output num is not equal!";
}
for (size_t output_index = 0; output_index < output_num; ++output_index) {
output_formats->emplace_back(kernel_attr.GetOutputAttr(output_index).second);
auto dtype = kernel_attr.GetOutputAttr(output_index).first;
output_types->emplace_back(dtype);
}
}
bool IsInputFormatDtypeMatched(const KernelAttr &kernel_attr, const std::vector<std::string> &input_formats,
const std::vector<TypeId> &input_types,
const std::vector<size_t> &input_not_cnode_indexes) {
if (kernel_attr.GetInputSize() != input_types.size()) {
MS_LOG(ERROR) << "Output num is not equal!";
return false;
}
auto input_num = input_types.size();
for (size_t i = 0; i < input_num; ++i) {
bool is_not_cnode_idx = std::any_of(input_not_cnode_indexes.begin(), input_not_cnode_indexes.end(),
[i](size_t index) { return index == i; });
if (is_not_cnode_idx) {
continue;
}
if (kernel_attr.GetInputAttr(i).first != input_types[i]) {
MS_LOG(ERROR) << "reg dtype=" << kernel_attr.GetInputAttr(i).first << ", input dtype=" << input_types[i];
return false;
}
if (kernel_attr.GetInputAttr(i).second != input_formats[i]) {
MS_LOG(ERROR) << "reg format=" << kernel_attr.GetInputAttr(i).second << ", input format=" << input_formats[i];
return false;
}
}
return true;
}
} // namespace
void SetKernelInfo(const CNodePtr &kernel_node) {
std::vector<std::string> input_formats;
std::vector<TypeId> input_types;
std::vector<size_t> input_not_cnode_indexes;
std::vector<std::string> output_formats;
std::vector<TypeId> output_types;
GetInputFormatsAndDtypes(kernel_node, &input_formats, &input_types, &input_not_cnode_indexes);
auto kernel_attrs =
kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node));
for (auto &kernel_attr : kernel_attrs) {
if (IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) {
GetOutputFormatsAndDtypes(kernel_node, kernel_attr, &output_formats, &output_types);
UpdatePrevNotCNodeFormatDtype(kernel_attr, input_not_cnode_indexes, kernel_node);
for (auto &input_index : input_not_cnode_indexes) {
input_types[input_index] = kernel_attr.GetInputAttr(input_index).first;
}
break;
}
}
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
MS_EXCEPTION_IF_NULL(builder);
builder->SetInputsFormat(input_formats);
builder->SetInputsDeviceType(input_types);
builder->SetOutputsFormat(output_formats);
builder->SetOutputsDeviceType(output_types);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get());
}
} // namespace cpu
} // namespace device
} // namespace mindspore

View File

@ -0,0 +1,63 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_DEVICE_CPU_KERNEL_SELECT_CPU_H_
#define MINDSPORE_CCSRC_DEVICE_CPU_KERNEL_SELECT_CPU_H_
#include <utility>
#include <string>
#include <vector>
#include "ir/anf.h"
#include "ir/dtype/type.h"
#include "utils/utils.h"
namespace mindspore {
namespace device {
namespace cpu {
void SetKernelInfo(const CNodePtr &apply_kernel_ptr);
class KernelAttr {
public:
using DataType = std::pair<TypeId, std::string>;
KernelAttr() = default;
~KernelAttr() = default;
KernelAttr &AddInputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) {
input_type_.emplace_back(ms_type, format);
return *this;
}
KernelAttr &AddOutputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) {
output_type_.emplace_back(ms_type, format);
return *this;
}
const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; }
const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; }
size_t GetInputSize() const { return input_type_.size(); }
size_t GetOutputSize() const { return output_type_.size(); }
private:
std::vector<DataType> input_type_;
std::vector<DataType> output_type_;
};
} // namespace cpu
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_DEVICE_CPU_KERNEL_SELECT_CPU_H_

View File

@ -33,7 +33,15 @@ class ApplyMomentumCPUKernel : public MKLCPUKernel {
const std::vector<AddressPtr> &outputs) override;
};
MS_REG_CPU_KERNEL(ApplyMomentum, ApplyMomentumCPUKernel);
MS_REG_CPU_KERNEL(ApplyMomentum,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
ApplyMomentumCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -37,7 +37,8 @@ class ArgmaxCPUKernel : public CPUKernel {
size_t batch_size_{0};
};
MS_REG_CPU_KERNEL(Argmax, ArgmaxCPUKernel);
MS_REG_CPU_KERNEL(Argmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArgmaxCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -37,7 +37,10 @@ class BiasAddCPUKernel : public CPUKernel {
std::vector<size_t> input_shape_;
std::vector<size_t> bias_shape_;
};
MS_REG_CPU_KERNEL(BiasAdd, BiasAddCPUKernel);
MS_REG_CPU_KERNEL(
BiasAdd,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BiasAddCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIAS_ADD_CPU_KERNEL_H_

View File

@ -36,7 +36,8 @@ class BiasAddGradCPUKernel : public CPUKernel {
private:
std::vector<size_t> input_shape_;
};
MS_REG_CPU_KERNEL(BiasAddGrad, BiasAddGradCPUKernel);
MS_REG_CPU_KERNEL(BiasAddGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BiasAddGradCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_MINDSPORE_CCSRC_KERNEL_CPU_BIASADDGRADCPUKERNEL_H_

View File

@ -20,29 +20,79 @@
#include <iostream>
#include <string>
#include "device/kernel_info.h"
namespace mindspore {
namespace kernel {
CPUKernelFactory &CPUKernelFactory::Get() {
CPUKernelFactory &CPUKernelFactory::GetInstance() {
static CPUKernelFactory instance;
return instance;
}
void CPUKernelFactory::Register(const std::string &kernel_name, CPUKernelCreator &&kernel_creator) {
if (kernel_creators_.find(kernel_name) == kernel_creators_.end()) {
(void)kernel_creators_.emplace(kernel_name, kernel_creator);
void CPUKernelFactory::Register(const std::string &kernel_name, const KernelAttr &kernel_attr,
CPUKernelCreator &&kernel_creator) {
(void)name_to_attr_creator_[kernel_name].emplace_back(kernel_attr, kernel_creator);
#if !defined(_WIN32) && !defined(_WIN64)
MS_LOG(DEBUG) << "CPUKernelFactory register operator: " << kernel_name;
MS_LOG(DEBUG) << "CPUKernelFactory register operator: " << kernel_name;
#endif
}
}
std::shared_ptr<CPUKernel> CPUKernelFactory::Create(const std::string &kernel_name) {
auto iter = kernel_creators_.find(kernel_name);
if (iter != kernel_creators_.end()) {
MS_EXCEPTION_IF_NULL(iter->second);
return (iter->second)();
std::shared_ptr<CPUKernel> CPUKernelFactory::Create(const std::string &kernel_name, const CNodePtr &apply_kernel) {
auto kernel_info = apply_kernel->kernel_info();
MS_EXCEPTION_IF_NULL(kernel_info);
const KernelBuildInfo *kernel_build_Info = kernel_info->select_kernel_build_info();
MS_EXCEPTION_IF_NULL(kernel_build_Info);
std::pair<bool, size_t> ret_pair = CPUKernelAttrCheck(kernel_name, kernel_build_Info);
if (ret_pair.first) {
return (name_to_attr_creator_.find(kernel_name)->second)[ret_pair.second].second();
}
return nullptr;
}
std::pair<bool, size_t> CPUKernelFactory::CPUKernelAttrCheck(const std::string &kernel_name,
const KernelBuildInfo *kernel_info) {
auto iter = name_to_attr_creator_.find(kernel_name);
if (iter == name_to_attr_creator_.end()) {
MS_LOG(INFO) << "Not registered CPU kernel: op[" << kernel_name << "]!";
return std::make_pair(false, 0);
}
auto creators = iter->second;
for (size_t index = 0; index < creators.size(); ++index) {
auto attr_creator = creators[index];
for (size_t i = 0; i < kernel_info->GetInputNum(); ++i) {
if (kernel_info->GetInputDeviceType(i) != attr_creator.first.GetInputAttr(i).first) {
MS_LOG(WARNING) << "cpu kernel attr check failed. input index: " << i << ".";
MS_LOG(WARNING) << "kernel info type:" << kernel_info->GetInputDeviceType(i) << ", "
<< "register type:" << attr_creator.first.GetInputAttr(i).first;
return std::make_pair(false, 0);
}
}
for (size_t i = 0; i < kernel_info->GetOutputNum(); ++i) {
if (kernel_info->GetOutputDeviceType(i) != attr_creator.first.GetOutputAttr(i).first) {
MS_LOG(WARNING) << "cpu kernel attr check failed. output index: " << i << ".";
MS_LOG(WARNING) << "kernel info type:" << kernel_info->GetOutputDeviceType(i) << ", "
<< "register type:" << attr_creator.first.GetOutputAttr(i).first;
return std::make_pair(false, 0);
}
}
return std::make_pair(true, index);
}
return std::make_pair(false, 0);
}
std::vector<KernelAttr> CPUKernelFactory::GetSupportedKernelAttrList(const std::string &kernel_name) {
std::vector<KernelAttr> result;
auto iter = name_to_attr_creator_.find(kernel_name);
if (iter == name_to_attr_creator_.end()) {
MS_LOG(WARNING) << "Not registered CPU kernel: op[" << kernel_name << "]!";
return result;
}
auto creators = iter->second;
for (size_t index = 0; index < creators.size(); ++index) {
auto attr_creator = creators[index];
result.push_back(attr_creator.first);
}
return result;
}
} // namespace kernel
} // namespace mindspore

View File

@ -21,35 +21,54 @@
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "common/utils.h"
#include "kernel/cpu/cpu_kernel.h"
#include "device/cpu/kernel_select_cpu.h"
namespace mindspore {
namespace kernel {
using mindspore::device::cpu::KernelAttr;
using CPUKernelCreator = std::function<std::shared_ptr<CPUKernel>()>;
class CPUKernelFactory {
public:
static CPUKernelFactory &Get();
void Register(const std::string &kernel_name, CPUKernelCreator &&kernel_creator);
static CPUKernelFactory &GetInstance();
void Register(const std::string &kernel_name, const KernelAttr &kernel_attr, CPUKernelCreator &&kernel_creator);
std::shared_ptr<CPUKernel> Create(const std::string &kernel_name);
std::shared_ptr<CPUKernel> Create(const std::string &kernel_name, const CNodePtr &apply_kernel);
std::vector<KernelAttr> GetSupportedKernelAttrList(const std::string &kernel_name);
private:
CPUKernelFactory() = default;
~CPUKernelFactory() = default;
DISABLE_COPY_AND_ASSIGN(CPUKernelFactory)
std::map<std::string, CPUKernelCreator> kernel_creators_;
std::pair<bool, size_t> CPUKernelAttrCheck(const std::string &kernel_name, const KernelBuildInfo *kernel_info);
std::map<std::string, std::vector<std::pair<KernelAttr, CPUKernelCreator>>> name_to_attr_creator_;
};
class CPUKernelRegistrar {
public:
CPUKernelRegistrar(const std::string &kernel_name, CPUKernelCreator &&kernel_creator) {
CPUKernelFactory::Get().Register(kernel_name, std::move(kernel_creator));
CPUKernelRegistrar(const std::string &kernel_name, const KernelAttr &kernel_attr, CPUKernelCreator &&kernel_creator) {
CPUKernelFactory::GetInstance().Register(kernel_name, kernel_attr, std::move(kernel_creator));
}
~CPUKernelRegistrar() = default;
};
#define MS_REG_CPU_KERNEL(KERNEL_NAME, KERNEL_CLASS) \
static const CPUKernelRegistrar g_cpu_kernel_##KERNEL_NAME##_reg(#KERNEL_NAME, \
[]() { return std::make_shared<KERNEL_CLASS>(); });
#define MS_REG_CPU_KERNEL(OPNAME, ATTR, OPCLASS) \
static_assert(std::is_base_of<CPUKernel, OPCLASS>::value, " must be base of CPUKernel"); \
static const CPUKernelRegistrar g_cpu_kernel_##OPNAME##_reg(#OPNAME, ATTR, \
[]() { return std::make_shared<OPCLASS>(); });
#define MS_REG_CPU_KERNEL_T(OPNAME, ATTR, OPCLASS, T) \
static_assert(std::is_base_of<CPUKernel, OPCLASS<T>>::value, " must be base of CPUKernel"); \
static const CPUKernelRegistrar g_cpu_kernel_##OPNAME##_##T##_reg(#OPNAME, ATTR, \
[]() { return std::make_shared<OPCLASS<T>>(); });
#define MS_REG_CPU_KERNEL_T_S(OPNAME, ATTR, OPCLASS, T, S) \
static_assert(std::is_base_of<CPUKernel, OPCLASS<T, S>>::value, " must be base of CPUKernel"); \
static const CPUKernelRegistrar g_cpu_kernel_##OPNAME##_##T##_##S##_reg( \
#OPNAME, ATTR, []() { return std::make_shared<OPCLASS<T, S>>(); });
} // namespace kernel
} // namespace mindspore

View File

@ -33,7 +33,10 @@ class EqualCountCPUKernel : public CPUKernel {
const std::vector<AddressPtr> &outputs) override;
};
MS_REG_CPU_KERNEL(EqualCount, EqualCountCPUKernel);
MS_REG_CPU_KERNEL(
EqualCount,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
EqualCountCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -33,7 +33,10 @@ class Conv2dCPUKernel : public MKLCPUKernel {
const std::vector<AddressPtr> &outputs) override;
};
MS_REG_CPU_KERNEL(Conv2D, Conv2dCPUKernel);
MS_REG_CPU_KERNEL(
Conv2D,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
Conv2dCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -33,7 +33,10 @@ class Conv2dGradFilterCPUKernel : public MKLCPUKernel {
const std::vector<AddressPtr> &outputs) override;
};
MS_REG_CPU_KERNEL(Conv2DBackpropFilter, Conv2dGradFilterCPUKernel);
MS_REG_CPU_KERNEL(
Conv2DBackpropFilter,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
Conv2dGradFilterCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -33,7 +33,10 @@ class Conv2dGradInputCPUKernel : public MKLCPUKernel {
const std::vector<AddressPtr> &outputs) override;
};
MS_REG_CPU_KERNEL(Conv2DBackpropInput, Conv2dGradInputCPUKernel);
MS_REG_CPU_KERNEL(
Conv2DBackpropInput,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
Conv2dGradInputCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -40,7 +40,10 @@ class MatMulCPUKernel : public MKLCPUKernel {
dnnl_dim_t dim_k_{0};
};
MS_REG_CPU_KERNEL(MatMul, MatMulCPUKernel);
MS_REG_CPU_KERNEL(
MatMul,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
MatMulCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -33,7 +33,9 @@ class MulCPUKernel : public MKLCPUKernel {
const std::vector<AddressPtr> &outputs) override;
};
MS_REG_CPU_KERNEL(Mul, MulCPUKernel);
MS_REG_CPU_KERNEL(
Mul, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
MulCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -33,7 +33,8 @@ class PoolingCPUKernel : public MKLCPUKernel {
const std::vector<AddressPtr> &outputs) override;
};
MS_REG_CPU_KERNEL(MaxPool, PoolingCPUKernel);
MS_REG_CPU_KERNEL(MaxPool, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
PoolingCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -43,7 +43,13 @@ class PoolingGradCPUKernel : public MKLCPUKernel {
std::vector<size_t> dst_shape_;
};
MS_REG_CPU_KERNEL(MaxPoolGrad, PoolingGradCPUKernel);
MS_REG_CPU_KERNEL(MaxPoolGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
PoolingGradCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -33,7 +33,7 @@ class ReluCPUKernel : public MKLCPUKernel {
const std::vector<AddressPtr> &outputs) override;
};
MS_REG_CPU_KERNEL(ReLU, ReluCPUKernel);
MS_REG_CPU_KERNEL(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), ReluCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -33,7 +33,10 @@ class ReluGradCPUKernel : public MKLCPUKernel {
const std::vector<AddressPtr> &outputs) override;
};
MS_REG_CPU_KERNEL(ReluGrad, ReluGradCPUKernel);
MS_REG_CPU_KERNEL(
ReluGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ReluGradCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -33,7 +33,8 @@ class SoftmaxCPUKernel : public MKLCPUKernel {
const std::vector<AddressPtr> &outputs) override;
};
MS_REG_CPU_KERNEL(Softmax, SoftmaxCPUKernel);
MS_REG_CPU_KERNEL(Softmax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SoftmaxCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -43,7 +43,10 @@ class SparseSoftmaxCrossEntropyWithLogitsCPUKernel : public MKLCPUKernel {
size_t batch_size_{0};
};
MS_REG_CPU_KERNEL(SparseSoftmaxCrossEntropyWithLogits, SparseSoftmaxCrossEntropyWithLogitsCPUKernel);
MS_REG_CPU_KERNEL(
SparseSoftmaxCrossEntropyWithLogits,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
SparseSoftmaxCrossEntropyWithLogitsCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -38,7 +38,13 @@ class OneHotCPUKernel : public CPUKernel {
size_t axis_;
};
MS_REG_CPU_KERNEL(OneHot, OneHotCPUKernel);
MS_REG_CPU_KERNEL(OneHot,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
OneHotCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -33,9 +33,12 @@ class ReshapeCPUKernel : public CPUKernel {
const std::vector<AddressPtr> &outputs) override;
};
MS_REG_CPU_KERNEL(Reshape, ReshapeCPUKernel);
MS_REG_CPU_KERNEL(Flatten, ReshapeCPUKernel);
MS_REG_CPU_KERNEL(ExpandDims, ReshapeCPUKernel);
MS_REG_CPU_KERNEL(Reshape, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ReshapeCPUKernel);
MS_REG_CPU_KERNEL(Flatten, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ReshapeCPUKernel);
MS_REG_CPU_KERNEL(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ReshapeCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -24,6 +24,7 @@
#include "device/kernel_runtime.h"
#include "predict/predict.h"
#include "kernel/cpu/cpu_kernel_factory.h"
#include "device/cpu/kernel_select_cpu.h"
namespace mindspore {
namespace session {
@ -63,43 +64,7 @@ void CPUSession::SetKernelInfo(const KernelGraph *kernel_graph) {
auto &kernel_nodes = kernel_graph->execution_order();
for (const auto &kernel_node : kernel_nodes) {
MS_EXCEPTION_IF_NULL(kernel_node);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t input_index = 0; input_index < input_num; ++input_index) {
auto input_kernel_node = kernel_node->input(input_index + 1);
MS_EXCEPTION_IF_NULL(input_kernel_node);
if (!input_kernel_node->isa<Parameter>()) {
continue;
}
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
MS_EXCEPTION_IF_NULL(builder);
std::vector<std::string> output_formats = {kOpFormat_DEFAULT};
builder->SetOutputsFormat(output_formats);
std::vector<TypeId> output_types{kNumberTypeFloat32};
builder->SetOutputsDeviceType(output_types);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
}
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
MS_EXCEPTION_IF_NULL(builder);
std::vector<std::string> input_formats;
std::vector<TypeId> input_types;
std::vector<std::string> output_formats;
std::vector<TypeId> output_types;
for (size_t input_index = 0; input_index < input_num; ++input_index) {
input_formats.emplace_back(kOpFormat_DEFAULT);
input_types.emplace_back(kNumberTypeFloat32);
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
for (size_t output_index = 0; output_index < output_num; ++output_index) {
output_formats.emplace_back(kOpFormat_DEFAULT);
output_types.emplace_back(kNumberTypeFloat32);
}
builder->SetInputsFormat(input_formats);
builder->SetInputsDeviceType(input_types);
builder->SetOutputsFormat(output_formats);
builder->SetOutputsDeviceType(output_types);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node.get());
device::cpu::SetKernelInfo(kernel_node);
}
}
@ -110,7 +75,8 @@ void CPUSession::BuildKernel(const KernelGraph *kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
MS_LOG(INFO) << "Cpu building operator[" << kernel_name << "].";
std::shared_ptr<kernel::CPUKernel> cpu_kernel = kernel::CPUKernelFactory::Get().Create(kernel_name);
std::shared_ptr<kernel::CPUKernel> cpu_kernel =
kernel::CPUKernelFactory::GetInstance().Create(kernel_name, kernel_node);
if (cpu_kernel == nullptr) {
MS_LOG(EXCEPTION) << "Operator[" << kernel_name << "] is not support.";
}