!2401 fix code review

Merge pull request !2401 from lianliguang/fix-code-review
This commit is contained in:
mindspore-ci-bot 2020-06-22 19:01:12 +08:00 committed by Gitee
commit b91d32708b
2 changed files with 22 additions and 17 deletions

View File

@ -106,7 +106,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
bool PriorityChooseItem(const std::vector<int> &cur_item, std::vector<int> *best_item) {
MS_EXCEPTION_IF_NULL(best_item);
if (cur_item.size() != best_item->size()) {
MS_LOG(ERROR) << "item size should be same!";
MS_LOG(ERROR) << "Item size should be same!";
return false;
}
// Update the best_item by comparing the cur_item and best_item
@ -280,8 +280,12 @@ bool RaiseDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_dat
bool CanDataTypeReduce(const std::vector<int> &datatype_indexes, int check_index,
const std::vector<int> &node_mix_precision_datatype_index) {
return datatype_indexes[check_index] != kUnSupportMixedDataTypeIndex &&
datatype_indexes[check_index] <= node_mix_precision_datatype_index[check_index];
auto check_index_tmp = IntToSize(check_index);
if (check_index_tmp < datatype_indexes.size() && check_index_tmp < node_mix_precision_datatype_index.size()) {
return datatype_indexes[check_index] != kUnSupportMixedDataTypeIndex &&
datatype_indexes[check_index] <= node_mix_precision_datatype_index[check_index];
}
MS_LOG(EXCEPTION) << "Check index " << check_index << "is outof range";
}
bool RaiseOrReduceDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_datatype_index,
@ -300,10 +304,10 @@ bool RaiseOrReduceDataTypePrecisionSelect(const std::vector<int> &node_mix_preci
if (node_mix_precision_datatype_index[i] == kUnSupportMixedDataTypeIndex) {
auto find_iter = kernel_support_datatypes.find(iter->first);
if (find_iter == kernel_support_datatypes.end()) {
MS_LOG(EXCEPTION) << "kernel datatype index:%lu can not be found " << iter->first;
MS_LOG(EXCEPTION) << "Kernel datatype index:%lu can not be found " << iter->first;
}
if (i >= find_iter->second.size()) {
MS_LOG(EXCEPTION) << "node index " << i << " >= kernel datatype size " << find_iter->second.size();
MS_LOG(EXCEPTION) << "Node index " << i << " >= kernel datatype size " << find_iter->second.size();
}
if (node_mix_precision_datatype[i] != find_iter->second[i]) {
iter = kernel_match_datatype_idx->erase(iter);
@ -314,7 +318,7 @@ bool RaiseOrReduceDataTypePrecisionSelect(const std::vector<int> &node_mix_preci
}
auto datatype_indexes = iter->second;
if (i >= datatype_indexes.size()) {
MS_LOG(EXCEPTION) << "index " << i << "> kernel datatype indexes size " << datatype_indexes.size();
MS_LOG(EXCEPTION) << "Index " << i << "> kernel datatype indexes size " << datatype_indexes.size();
}
if (!CanDataTypeReduce(datatype_indexes, i, node_mix_precision_datatype_index)) {
iter = kernel_match_datatype_idx->erase(iter);
@ -384,9 +388,9 @@ void PrintRaiseOrReducePrecisionSelectedInfo(const CNodePtr &cnode,
std::ostringstream buffer;
buffer << cnode->DebugString();
if (precision_reduce) {
buffer << " reduce precision, node datatype: \n";
buffer << " Reduce precision, node datatype: \n";
} else {
buffer << " raise precision, node datatype: \n";
buffer << " Raise precision, node datatype: \n";
}
PrintInputAndOutputInferType(buffer, cnode);
buffer << ", select kernel:" << selected_kernel_build_info->ToString();
@ -554,12 +558,12 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kern
if (select_status == kNoMatched) {
std::ostringstream buffer;
PrintInputAndOutputInferType(buffer, kernel_node);
MS_LOG(WARNING) << ">>> candidates kernel info list:";
MS_LOG(WARNING) << ">>> Candidates kernel info list:";
for (size_t index = 0; index < kernel_info_list.size(); ++index) {
MS_LOG(WARNING) << "kernel [" << index << "] :" << kernel_info_list[index]->ToString();
MS_LOG(WARNING) << "Kernel [" << index << "] :" << kernel_info_list[index]->ToString();
}
for (size_t index = 0; index < aicpu_kernel_info_list.size(); ++index) {
MS_LOG(WARNING) << "kernel [" << (kernel_info_list.size() + index)
MS_LOG(WARNING) << "Kernel [" << (kernel_info_list.size() + index)
<< "] :" << aicpu_kernel_info_list[index]->ToString();
}
MS_LOG(WARNING) << " <<<";

View File

@ -32,7 +32,7 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
MS_EXCEPTION_IF_NULL(kernel_info_list);
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_list;
(void)std::copy_if(kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list),
[&](const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info) {
[&kernel_node](const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info) {
return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() &&
AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum();
});
@ -43,15 +43,16 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
MS_LOG(INFO) << "All kernel Info list does not match any kernel info ";
for (size_t index = 0; index < kernel_info_list->size(); ++index) {
std::ostringstream buffer;
MS_EXCEPTION_IF_NULL(kernel_info_list->at(index));
if (AnfAlgo::GetOutputTensorNum(kernel_node) != kernel_info_list->at(index)->GetOutputNum()) {
auto kernel_info = kernel_info_list->at(index);
MS_EXCEPTION_IF_NULL(kernel_info);
if (AnfAlgo::GetOutputTensorNum(kernel_node) != kernel_info->GetOutputNum()) {
buffer << "Kernel node's output size [" << AnfAlgo::GetOutputTensorNum(kernel_node) << "]"
<< " cannot match the kernel's output size [" << kernel_info_list->at(index)->GetOutputNum() << "]";
<< " cannot match the kernel's output size [" << kernel_info->GetOutputNum() << "]";
} else {
buffer << "Kernel node's output size [" << AnfAlgo::GetInputTensorNum(kernel_node) << "]"
<< " cannot match the kernel's output size [" << kernel_info_list->at(index)->GetInputNum() << "]";
<< " cannot match the kernel's output size [" << kernel_info->GetInputNum() << "]";
}
MS_LOG(INFO) << "kernel [ " << index << " ] :" << kernel_info_list->at(index)->ToString() << buffer.str();
MS_LOG(INFO) << "kernel [ " << index << " ] :" << kernel_info->ToString() << buffer.str();
}
kernel_info_list->clear();
MS_LOG(INFO) << "node" << kernel_node->DebugString() << "'s output size : ["