forked from mindspore-Ecosystem/mindspore
!18222 fix code review problems
Merge pull request !18222 from chengxb7532/master
This commit is contained in:
commit
b18a1e5b91
|
@ -54,6 +54,9 @@ void CPUKernelFactory::SetKernelAttrs(const std::shared_ptr<kernel::OpInfo> op_i
|
||||||
std::vector<KernelAttr> *kernel_attrs) {
|
std::vector<KernelAttr> *kernel_attrs) {
|
||||||
auto inputs_ptr = op_info->inputs_ptr();
|
auto inputs_ptr = op_info->inputs_ptr();
|
||||||
auto outputs_ptr = op_info->outputs_ptr();
|
auto outputs_ptr = op_info->outputs_ptr();
|
||||||
|
if (inputs_ptr.empty()) {
|
||||||
|
MS_LOG(EXCEPTION) << "op " << op_info->op_name() << " input size is zero.";
|
||||||
|
}
|
||||||
auto first_input_dtypes = inputs_ptr[0]->dtypes();
|
auto first_input_dtypes = inputs_ptr[0]->dtypes();
|
||||||
auto input_formats = inputs_ptr[0]->formats();
|
auto input_formats = inputs_ptr[0]->formats();
|
||||||
|
|
||||||
|
@ -82,8 +85,7 @@ void CPUKernelFactory::UpdateKernelAttrs(const std::string &kernel_name, const s
|
||||||
std::vector<std::pair<KernelAttr, CPUKernelCreator>> attr_creators(attr_size);
|
std::vector<std::pair<KernelAttr, CPUKernelCreator>> attr_creators(attr_size);
|
||||||
auto iter = name_to_attr_creator_.find(kernel_name);
|
auto iter = name_to_attr_creator_.find(kernel_name);
|
||||||
if (iter == name_to_attr_creator_.end()) {
|
if (iter == name_to_attr_creator_.end()) {
|
||||||
MS_LOG(ERROR) << "CPUKernelFactory has not registered operator: " << kernel_name;
|
MS_LOG(EXCEPTION) << "CPUKernelFactory has not registered operator: " << kernel_name;
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (attr_size <= iter->second.size()) {
|
if (attr_size <= iter->second.size()) {
|
||||||
|
@ -113,7 +115,7 @@ std::pair<bool, size_t> CPUKernelFactory::CPUKernelAttrCheck(const std::string &
|
||||||
if (kernel_attrs[0].GetInputSize() == 0 && kernel_attrs[0].GetOutputSize() == 0) {
|
if (kernel_attrs[0].GetInputSize() == 0 && kernel_attrs[0].GetOutputSize() == 0) {
|
||||||
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(kernel_name, kernel::OpImplyType::kCPU);
|
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(kernel_name, kernel::OpImplyType::kCPU);
|
||||||
if (op_info_ptr == nullptr) {
|
if (op_info_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "Not find op[" << kernel_name << "] in cpu";
|
MS_LOG(EXCEPTION) << "Not find op[" << kernel_name << "] in cpu";
|
||||||
}
|
}
|
||||||
kernel_attrs.clear();
|
kernel_attrs.clear();
|
||||||
SetKernelAttrs(op_info_ptr, &kernel_attrs);
|
SetKernelAttrs(op_info_ptr, &kernel_attrs);
|
||||||
|
@ -152,8 +154,7 @@ std::vector<KernelAttr> CPUKernelFactory::GetSupportedKernelAttrList(const std::
|
||||||
std::vector<KernelAttr> result;
|
std::vector<KernelAttr> result;
|
||||||
auto iter = name_to_attr_creator_.find(kernel_name);
|
auto iter = name_to_attr_creator_.find(kernel_name);
|
||||||
if (iter == name_to_attr_creator_.end()) {
|
if (iter == name_to_attr_creator_.end()) {
|
||||||
MS_LOG(WARNING) << "Not registered CPU kernel: op[" << kernel_name << "]!";
|
MS_LOG(EXCEPTION) << "Not registered CPU kernel: op[" << kernel_name << "]!";
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
auto creators = iter->second;
|
auto creators = iter->second;
|
||||||
for (size_t index = 0; index < creators.size(); ++index) {
|
for (size_t index = 0; index < creators.size(); ++index) {
|
||||||
|
|
|
@ -188,6 +188,7 @@ bool OpLib::RegOpFromLocalInfo() {
|
||||||
(void)OpLib::RegOp(line, "");
|
(void)OpLib::RegOp(line, "");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
file.close();
|
||||||
MS_LOG(INFO) << "End";
|
MS_LOG(INFO) << "End";
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -244,7 +244,7 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
|
||||||
std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
|
std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||||
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kCPU);
|
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kCPU);
|
||||||
if (op_info_ptr == nullptr) {
|
if (op_info_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "Not find op[" << op_name << "] in cpu";
|
MS_LOG(EXCEPTION) << "Not find op[" << op_name << "] in cpu";
|
||||||
}
|
}
|
||||||
kernel_attrs.clear();
|
kernel_attrs.clear();
|
||||||
kernel::CPUKernelFactory::GetInstance().SetKernelAttrs(op_info_ptr, &kernel_attrs);
|
kernel::CPUKernelFactory::GetInstance().SetKernelAttrs(op_info_ptr, &kernel_attrs);
|
||||||
|
@ -260,8 +260,8 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
|
||||||
KernelNotSupportException(kernel_node, input_types, output_types);
|
KernelNotSupportException(kernel_node, input_types, output_types);
|
||||||
}
|
}
|
||||||
matched = std::make_pair(false, false);
|
matched = std::make_pair(false, false);
|
||||||
SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_types, input_not_cnode_indexes, output_types,
|
(void)SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_types, input_not_cnode_indexes,
|
||||||
&matched, false);
|
output_types, &matched, false);
|
||||||
if (!matched.first) {
|
if (!matched.first) {
|
||||||
KernelNotSupportException(kernel_node, input_types, output_types);
|
KernelNotSupportException(kernel_node, input_types, output_types);
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,10 +24,16 @@ minimum_grad_op_info = CpuRegOp("MinimumGrad") \
|
||||||
.output(1, "y2", "required") \
|
.output(1, "y2", "required") \
|
||||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, \
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, \
|
||||||
DataType.I32_Default, DataType.I32_Default) \
|
DataType.I32_Default, DataType.I32_Default) \
|
||||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, \
|
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, \
|
||||||
DataType.F16_Default, DataType.F16_Default) \
|
DataType.U32_Default, DataType.U32_Default) \
|
||||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, \
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, \
|
||||||
DataType.F32_Default, DataType.F32_Default) \
|
DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, \
|
||||||
|
DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, \
|
||||||
|
DataType.U64_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, \
|
||||||
|
DataType.F64_Default, DataType.F64_Default) \
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue