erase datatype raise kernel

This commit is contained in:
chenjianping 2020-04-26 10:52:53 +00:00
parent 001912237e
commit 6d47036f95
1 changed files with 9 additions and 11 deletions

View File

@ -218,16 +218,7 @@ void AddNodeInputDataType(const CNodePtr &kernel_node, size_t input_index,
std::vector<TypeId> *node_mix_precision_datatype) { std::vector<TypeId> *node_mix_precision_datatype) {
AnfNodePtr cur_input = AnfAlgo::GetInputNode(kernel_node, input_index); AnfNodePtr cur_input = AnfAlgo::GetInputNode(kernel_node, input_index);
MS_EXCEPTION_IF_NULL(cur_input); MS_EXCEPTION_IF_NULL(cur_input);
TypeId input_origin_type; TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index);
if (cur_input->isa<Parameter>() && AnfAlgo::IsParameterWeight(cur_input->cast<ParameterPtr>())) {
// weight
input_origin_type = AnfAlgo::GetOutputDeviceDataType(cur_input, 0);
} else if (cur_input->isa<ValueNode>()) {
input_origin_type = AnfAlgo::GetOutputDeviceDataType(cur_input, 0);
} else {
// feature map
input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index);
}
AddSupportMixedPrecisionDataTypeIndex(input_origin_type, node_mix_precision_datatype_index); AddSupportMixedPrecisionDataTypeIndex(input_origin_type, node_mix_precision_datatype_index);
node_mix_precision_datatype->push_back(input_origin_type); node_mix_precision_datatype->push_back(input_origin_type);
} }
@ -297,6 +288,12 @@ bool RaiseDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_dat
return !kernel_match_datatype_idx->empty(); return !kernel_match_datatype_idx->empty();
} }
bool CanDataTypeReduce(const std::vector<int> &datatype_indexes, int check_index,
const std::vector<int> &node_mix_precision_datatype_index) {
return datatype_indexes[check_index] != kUnSupportMixedDataTypeIndex &&
datatype_indexes[check_index] <= node_mix_precision_datatype_index[check_index];
}
bool RaiseOrReduceDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_datatype_index, bool RaiseOrReduceDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_datatype_index,
const std::vector<TypeId> &node_mix_precision_datatype, const std::vector<TypeId> &node_mix_precision_datatype,
const std::map<size_t, std::vector<TypeId>> &kernel_support_datatypes, const std::map<size_t, std::vector<TypeId>> &kernel_support_datatypes,
@ -329,7 +326,7 @@ bool RaiseOrReduceDataTypePrecisionSelect(const std::vector<int> &node_mix_preci
if (i >= datatype_indexes.size()) { if (i >= datatype_indexes.size()) {
MS_LOG(EXCEPTION) << "index " << i << "> kernel datatype indexes size " << datatype_indexes.size(); MS_LOG(EXCEPTION) << "index " << i << "> kernel datatype indexes size " << datatype_indexes.size();
} }
if (datatype_indexes[i] == kUnSupportMixedDataTypeIndex) { if (!CanDataTypeReduce(datatype_indexes, i, node_mix_precision_datatype_index)) {
iter = kernel_match_datatype_idx->erase(iter); iter = kernel_match_datatype_idx->erase(iter);
} else { } else {
++iter; ++iter;
@ -376,6 +373,7 @@ void PrecisionReduce(const std::vector<int> &node_mix_precision_datatype_index,
bool selected_ret = RaiseDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype, bool selected_ret = RaiseDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
kernel_support_datatype, kernel_match_datatype_idx); kernel_support_datatype, kernel_match_datatype_idx);
if (selected_ret) { if (selected_ret) {
*precision_reduce = false;
return; return;
} }
if (context_ptr->enable_reduce_precision()) { if (context_ptr->enable_reduce_precision()) {