!18222 fix code review problems

Merge pull request !18222 from chengxb7532/master
This commit is contained in:
i-robot 2021-06-16 09:17:31 +08:00 committed by Gitee
commit b18a1e5b91
4 changed files with 18 additions and 10 deletions

View File

@ -54,6 +54,9 @@ void CPUKernelFactory::SetKernelAttrs(const std::shared_ptr<kernel::OpInfo> op_i
std::vector<KernelAttr> *kernel_attrs) {
auto inputs_ptr = op_info->inputs_ptr();
auto outputs_ptr = op_info->outputs_ptr();
if (inputs_ptr.empty()) {
MS_LOG(EXCEPTION) << "op " << op_info->op_name() << " input size is zero.";
}
auto first_input_dtypes = inputs_ptr[0]->dtypes();
auto input_formats = inputs_ptr[0]->formats();
@ -82,8 +85,7 @@ void CPUKernelFactory::UpdateKernelAttrs(const std::string &kernel_name, const s
std::vector<std::pair<KernelAttr, CPUKernelCreator>> attr_creators(attr_size);
auto iter = name_to_attr_creator_.find(kernel_name);
if (iter == name_to_attr_creator_.end()) {
MS_LOG(ERROR) << "CPUKernelFactory has not registered operator: " << kernel_name;
return;
MS_LOG(EXCEPTION) << "CPUKernelFactory has not registered operator: " << kernel_name;
}
if (attr_size <= iter->second.size()) {
@ -113,7 +115,7 @@ std::pair<bool, size_t> CPUKernelFactory::CPUKernelAttrCheck(const std::string &
if (kernel_attrs[0].GetInputSize() == 0 && kernel_attrs[0].GetOutputSize() == 0) {
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(kernel_name, kernel::OpImplyType::kCPU);
if (op_info_ptr == nullptr) {
MS_LOG(ERROR) << "Not find op[" << kernel_name << "] in cpu";
MS_LOG(EXCEPTION) << "Not find op[" << kernel_name << "] in cpu";
}
kernel_attrs.clear();
SetKernelAttrs(op_info_ptr, &kernel_attrs);
@ -152,8 +154,7 @@ std::vector<KernelAttr> CPUKernelFactory::GetSupportedKernelAttrList(const std::
std::vector<KernelAttr> result;
auto iter = name_to_attr_creator_.find(kernel_name);
if (iter == name_to_attr_creator_.end()) {
MS_LOG(WARNING) << "Not registered CPU kernel: op[" << kernel_name << "]!";
return result;
MS_LOG(EXCEPTION) << "Not registered CPU kernel: op[" << kernel_name << "]!";
}
auto creators = iter->second;
for (size_t index = 0; index < creators.size(); ++index) {

View File

@ -188,6 +188,7 @@ bool OpLib::RegOpFromLocalInfo() {
(void)OpLib::RegOp(line, "");
}
}
file.close();
MS_LOG(INFO) << "End";
return true;
}

View File

@ -244,7 +244,7 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kCPU);
if (op_info_ptr == nullptr) {
MS_LOG(ERROR) << "Not find op[" << op_name << "] in cpu";
MS_LOG(EXCEPTION) << "Not find op[" << op_name << "] in cpu";
}
kernel_attrs.clear();
kernel::CPUKernelFactory::GetInstance().SetKernelAttrs(op_info_ptr, &kernel_attrs);
@ -260,8 +260,8 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
KernelNotSupportException(kernel_node, input_types, output_types);
}
matched = std::make_pair(false, false);
SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_types, input_not_cnode_indexes, output_types,
&matched, false);
(void)SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_types, input_not_cnode_indexes,
output_types, &matched, false);
if (!matched.first) {
KernelNotSupportException(kernel_node, input_types, output_types);
}

View File

@ -24,10 +24,16 @@ minimum_grad_op_info = CpuRegOp("MinimumGrad") \
.output(1, "y2", "required") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, \
DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, \
DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, \
DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, \
DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, \
DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, \
DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, \
DataType.F64_Default, DataType.F64_Default) \
.get_op_info()