!688 [MS]fix kernel select bug

Merge pull request !688 from chenjianping/fix-bugs
This commit is contained in:
mindspore-ci-bot 2020-04-26 20:21:23 +08:00 committed by Gitee
commit da8c74b54c
1 changed files with 9 additions and 11 deletions

View File

@ -218,16 +218,7 @@ void AddNodeInputDataType(const CNodePtr &kernel_node, size_t input_index,
std::vector<TypeId> *node_mix_precision_datatype) {
AnfNodePtr cur_input = AnfAlgo::GetInputNode(kernel_node, input_index);
MS_EXCEPTION_IF_NULL(cur_input);
TypeId input_origin_type;
if (cur_input->isa<Parameter>() && AnfAlgo::IsParameterWeight(cur_input->cast<ParameterPtr>())) {
// weight
input_origin_type = AnfAlgo::GetOutputDeviceDataType(cur_input, 0);
} else if (cur_input->isa<ValueNode>()) {
input_origin_type = AnfAlgo::GetOutputDeviceDataType(cur_input, 0);
} else {
// feature map
input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index);
}
TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index);
AddSupportMixedPrecisionDataTypeIndex(input_origin_type, node_mix_precision_datatype_index);
node_mix_precision_datatype->push_back(input_origin_type);
}
@ -297,6 +288,12 @@ bool RaiseDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_dat
return !kernel_match_datatype_idx->empty();
}
bool CanDataTypeReduce(const std::vector<int> &datatype_indexes, int check_index,
const std::vector<int> &node_mix_precision_datatype_index) {
return datatype_indexes[check_index] != kUnSupportMixedDataTypeIndex &&
datatype_indexes[check_index] <= node_mix_precision_datatype_index[check_index];
}
bool RaiseOrReduceDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_datatype_index,
const std::vector<TypeId> &node_mix_precision_datatype,
const std::map<size_t, std::vector<TypeId>> &kernel_support_datatypes,
@ -329,7 +326,7 @@ bool RaiseOrReduceDataTypePrecisionSelect(const std::vector<int> &node_mix_preci
if (i >= datatype_indexes.size()) {
MS_LOG(EXCEPTION) << "index " << i << "> kernel datatype indexes size " << datatype_indexes.size();
}
if (datatype_indexes[i] == kUnSupportMixedDataTypeIndex) {
if (!CanDataTypeReduce(datatype_indexes, i, node_mix_precision_datatype_index)) {
iter = kernel_match_datatype_idx->erase(iter);
} else {
++iter;
@ -376,6 +373,7 @@ void PrecisionReduce(const std::vector<int> &node_mix_precision_datatype_index,
bool selected_ret = RaiseDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
kernel_support_datatype, kernel_match_datatype_idx);
if (selected_ret) {
*precision_reduce = false;
return;
}
if (context_ptr->enable_reduce_precision()) {