diff --git a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc index 549b97b61be..1efd3d6c22c 100644 --- a/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/device/ascend/kernel_select_ascend.cc @@ -218,16 +218,7 @@ void AddNodeInputDataType(const CNodePtr &kernel_node, size_t input_index, std::vector *node_mix_precision_datatype) { AnfNodePtr cur_input = AnfAlgo::GetInputNode(kernel_node, input_index); MS_EXCEPTION_IF_NULL(cur_input); - TypeId input_origin_type; - if (cur_input->isa() && AnfAlgo::IsParameterWeight(cur_input->cast())) { - // weight - input_origin_type = AnfAlgo::GetOutputDeviceDataType(cur_input, 0); - } else if (cur_input->isa()) { - input_origin_type = AnfAlgo::GetOutputDeviceDataType(cur_input, 0); - } else { - // feature map - input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index); - } + TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index); AddSupportMixedPrecisionDataTypeIndex(input_origin_type, node_mix_precision_datatype_index); node_mix_precision_datatype->push_back(input_origin_type); } @@ -297,6 +288,12 @@ bool RaiseDataTypePrecisionSelect(const std::vector &node_mix_precision_dat return !kernel_match_datatype_idx->empty(); } +bool CanDataTypeReduce(const std::vector &datatype_indexes, int check_index, + const std::vector &node_mix_precision_datatype_index) { + return datatype_indexes[check_index] != kUnSupportMixedDataTypeIndex && + datatype_indexes[check_index] <= node_mix_precision_datatype_index[check_index]; +} + bool RaiseOrReduceDataTypePrecisionSelect(const std::vector &node_mix_precision_datatype_index, const std::vector &node_mix_precision_datatype, const std::map> &kernel_support_datatypes, @@ -329,7 +326,7 @@ bool RaiseOrReduceDataTypePrecisionSelect(const std::vector &node_mix_preci if (i >= datatype_indexes.size()) { MS_LOG(EXCEPTION) << "index " << i << "> kernel datatype indexes size " << datatype_indexes.size(); } - if (datatype_indexes[i] == kUnSupportMixedDataTypeIndex) { + if (!CanDataTypeReduce(datatype_indexes, i, node_mix_precision_datatype_index)) { iter = kernel_match_datatype_idx->erase(iter); } else { ++iter; @@ -376,6 +373,7 @@ void PrecisionReduce(const std::vector &node_mix_precision_datatype_index, bool selected_ret = RaiseDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatype, kernel_match_datatype_idx); if (selected_ret) { + *precision_reduce = false; return; } if (context_ptr->enable_reduce_precision()) {