!18222 fix code review problems

Merge pull request !18222 from chengxb7532/master
This commit is contained in:
i-robot 2021-06-16 09:17:31 +08:00 committed by Gitee
commit b18a1e5b91
4 changed files with 18 additions and 10 deletions

View File

@ -54,6 +54,9 @@ void CPUKernelFactory::SetKernelAttrs(const std::shared_ptr<kernel::OpInfo> op_i
std::vector<KernelAttr> *kernel_attrs) { std::vector<KernelAttr> *kernel_attrs) {
auto inputs_ptr = op_info->inputs_ptr(); auto inputs_ptr = op_info->inputs_ptr();
auto outputs_ptr = op_info->outputs_ptr(); auto outputs_ptr = op_info->outputs_ptr();
if (inputs_ptr.empty()) {
MS_LOG(EXCEPTION) << "op " << op_info->op_name() << " input size is zero.";
}
auto first_input_dtypes = inputs_ptr[0]->dtypes(); auto first_input_dtypes = inputs_ptr[0]->dtypes();
auto input_formats = inputs_ptr[0]->formats(); auto input_formats = inputs_ptr[0]->formats();
@ -82,8 +85,7 @@ void CPUKernelFactory::UpdateKernelAttrs(const std::string &kernel_name, const s
std::vector<std::pair<KernelAttr, CPUKernelCreator>> attr_creators(attr_size); std::vector<std::pair<KernelAttr, CPUKernelCreator>> attr_creators(attr_size);
auto iter = name_to_attr_creator_.find(kernel_name); auto iter = name_to_attr_creator_.find(kernel_name);
if (iter == name_to_attr_creator_.end()) { if (iter == name_to_attr_creator_.end()) {
MS_LOG(ERROR) << "CPUKernelFactory has not registered operator: " << kernel_name; MS_LOG(EXCEPTION) << "CPUKernelFactory has not registered operator: " << kernel_name;
return;
} }
if (attr_size <= iter->second.size()) { if (attr_size <= iter->second.size()) {
@ -113,7 +115,7 @@ std::pair<bool, size_t> CPUKernelFactory::CPUKernelAttrCheck(const std::string &
if (kernel_attrs[0].GetInputSize() == 0 && kernel_attrs[0].GetOutputSize() == 0) { if (kernel_attrs[0].GetInputSize() == 0 && kernel_attrs[0].GetOutputSize() == 0) {
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(kernel_name, kernel::OpImplyType::kCPU); auto op_info_ptr = mindspore::kernel::OpLib::FindOp(kernel_name, kernel::OpImplyType::kCPU);
if (op_info_ptr == nullptr) { if (op_info_ptr == nullptr) {
MS_LOG(ERROR) << "Not find op[" << kernel_name << "] in cpu"; MS_LOG(EXCEPTION) << "Not find op[" << kernel_name << "] in cpu";
} }
kernel_attrs.clear(); kernel_attrs.clear();
SetKernelAttrs(op_info_ptr, &kernel_attrs); SetKernelAttrs(op_info_ptr, &kernel_attrs);
@ -152,8 +154,7 @@ std::vector<KernelAttr> CPUKernelFactory::GetSupportedKernelAttrList(const std::
std::vector<KernelAttr> result; std::vector<KernelAttr> result;
auto iter = name_to_attr_creator_.find(kernel_name); auto iter = name_to_attr_creator_.find(kernel_name);
if (iter == name_to_attr_creator_.end()) { if (iter == name_to_attr_creator_.end()) {
MS_LOG(WARNING) << "Not registered CPU kernel: op[" << kernel_name << "]!"; MS_LOG(EXCEPTION) << "Not registered CPU kernel: op[" << kernel_name << "]!";
return result;
} }
auto creators = iter->second; auto creators = iter->second;
for (size_t index = 0; index < creators.size(); ++index) { for (size_t index = 0; index < creators.size(); ++index) {

View File

@ -188,6 +188,7 @@ bool OpLib::RegOpFromLocalInfo() {
(void)OpLib::RegOp(line, ""); (void)OpLib::RegOp(line, "");
} }
} }
file.close();
MS_LOG(INFO) << "End"; MS_LOG(INFO) << "End";
return true; return true;
} }

View File

@ -244,7 +244,7 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
std::string op_name = AnfAlgo::GetCNodeName(kernel_node); std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kCPU); auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, kernel::OpImplyType::kCPU);
if (op_info_ptr == nullptr) { if (op_info_ptr == nullptr) {
MS_LOG(ERROR) << "Not find op[" << op_name << "] in cpu"; MS_LOG(EXCEPTION) << "Not find op[" << op_name << "] in cpu";
} }
kernel_attrs.clear(); kernel_attrs.clear();
kernel::CPUKernelFactory::GetInstance().SetKernelAttrs(op_info_ptr, &kernel_attrs); kernel::CPUKernelFactory::GetInstance().SetKernelAttrs(op_info_ptr, &kernel_attrs);
@ -260,8 +260,8 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
KernelNotSupportException(kernel_node, input_types, output_types); KernelNotSupportException(kernel_node, input_types, output_types);
} }
matched = std::make_pair(false, false); matched = std::make_pair(false, false);
SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_types, input_not_cnode_indexes, output_types, (void)SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_types, input_not_cnode_indexes,
&matched, false); output_types, &matched, false);
if (!matched.first) { if (!matched.first) {
KernelNotSupportException(kernel_node, input_types, output_types); KernelNotSupportException(kernel_node, input_types, output_types);
} }

View File

@ -24,10 +24,16 @@ minimum_grad_op_info = CpuRegOp("MinimumGrad") \
.output(1, "y2", "required") \ .output(1, "y2", "required") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, \ .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, \
DataType.I32_Default, DataType.I32_Default) \ DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, \ .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default, \
DataType.F16_Default, DataType.F16_Default) \ DataType.U32_Default, DataType.U32_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, \ .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, \
DataType.F32_Default, DataType.F32_Default) \ DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default, \
DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default, \
DataType.U64_Default, DataType.U64_Default) \
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, \
DataType.F64_Default, DataType.F64_Default) \
.get_op_info() .get_op_info()