forked from mindspore-Ecosystem/mindspore
!688 [MS]fix kernel select bug
Merge pull request !688 from chenjianping/fix-bugs
This commit is contained in:
commit
da8c74b54c
|
@ -218,16 +218,7 @@ void AddNodeInputDataType(const CNodePtr &kernel_node, size_t input_index,
|
|||
std::vector<TypeId> *node_mix_precision_datatype) {
|
||||
AnfNodePtr cur_input = AnfAlgo::GetInputNode(kernel_node, input_index);
|
||||
MS_EXCEPTION_IF_NULL(cur_input);
|
||||
TypeId input_origin_type;
|
||||
if (cur_input->isa<Parameter>() && AnfAlgo::IsParameterWeight(cur_input->cast<ParameterPtr>())) {
|
||||
// weight
|
||||
input_origin_type = AnfAlgo::GetOutputDeviceDataType(cur_input, 0);
|
||||
} else if (cur_input->isa<ValueNode>()) {
|
||||
input_origin_type = AnfAlgo::GetOutputDeviceDataType(cur_input, 0);
|
||||
} else {
|
||||
// feature map
|
||||
input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index);
|
||||
}
|
||||
TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index);
|
||||
AddSupportMixedPrecisionDataTypeIndex(input_origin_type, node_mix_precision_datatype_index);
|
||||
node_mix_precision_datatype->push_back(input_origin_type);
|
||||
}
|
||||
|
@ -297,6 +288,12 @@ bool RaiseDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_dat
|
|||
return !kernel_match_datatype_idx->empty();
|
||||
}
|
||||
|
||||
bool CanDataTypeReduce(const std::vector<int> &datatype_indexes, int check_index,
|
||||
const std::vector<int> &node_mix_precision_datatype_index) {
|
||||
return datatype_indexes[check_index] != kUnSupportMixedDataTypeIndex &&
|
||||
datatype_indexes[check_index] <= node_mix_precision_datatype_index[check_index];
|
||||
}
|
||||
|
||||
bool RaiseOrReduceDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_datatype_index,
|
||||
const std::vector<TypeId> &node_mix_precision_datatype,
|
||||
const std::map<size_t, std::vector<TypeId>> &kernel_support_datatypes,
|
||||
|
@ -329,7 +326,7 @@ bool RaiseOrReduceDataTypePrecisionSelect(const std::vector<int> &node_mix_preci
|
|||
if (i >= datatype_indexes.size()) {
|
||||
MS_LOG(EXCEPTION) << "index " << i << "> kernel datatype indexes size " << datatype_indexes.size();
|
||||
}
|
||||
if (datatype_indexes[i] == kUnSupportMixedDataTypeIndex) {
|
||||
if (!CanDataTypeReduce(datatype_indexes, i, node_mix_precision_datatype_index)) {
|
||||
iter = kernel_match_datatype_idx->erase(iter);
|
||||
} else {
|
||||
++iter;
|
||||
|
@ -376,6 +373,7 @@ void PrecisionReduce(const std::vector<int> &node_mix_precision_datatype_index,
|
|||
bool selected_ret = RaiseDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
|
||||
kernel_support_datatype, kernel_match_datatype_idx);
|
||||
if (selected_ret) {
|
||||
*precision_reduce = false;
|
||||
return;
|
||||
}
|
||||
if (context_ptr->enable_reduce_precision()) {
|
||||
|
|
Loading…
Reference in New Issue