From 020c71e72c31bede6ca21c40ea4c5e973993d3ff Mon Sep 17 00:00:00 2001 From: liubuyu Date: Fri, 16 Oct 2020 19:59:02 +0800 Subject: [PATCH] reduce or raise precision restructure --- .../device/ascend/kernel_select_ascend.cc | 316 +++++------------- 1 file changed, 78 insertions(+), 238 deletions(-) diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc index cd10ee042ae..fe03cd92105 100644 --- a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc @@ -16,23 +16,23 @@ #include "runtime/device/ascend/kernel_select_ascend.h" -#include -#include -#include -#include #include #include +#include +#include #include #include -#include "utils/ms_utils.h" +#include +#include +#include "backend/kernel_compiler/kernel_build_info.h" +#include "backend/kernel_compiler/kernel_query.h" +#include "backend/kernel_compiler/oplib/oplib.h" #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h" +#include "backend/session/anf_runtime_algorithm.h" #include "debug/anf_ir_dump.h" #include "frontend/operator/ops.h" #include "utils/ms_context.h" -#include "backend/session/anf_runtime_algorithm.h" -#include "backend/kernel_compiler/kernel_query.h" -#include "backend/kernel_compiler/oplib/oplib.h" -#include "backend/kernel_compiler/kernel_build_info.h" +#include "utils/ms_utils.h" namespace mindspore { namespace device { @@ -172,218 +172,6 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons } } -void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector *support_index) { - MS_EXCEPTION_IF_NULL(support_index); - int index = kUnSupportMixedDataTypeIndex; - switch (data_type) { - case kNumberTypeFloat16: - index = 0; - break; - case kNumberTypeFloat32: - case kNumberTypeFloat: - index = 1; - break; - default: - break; - } - support_index->push_back(index); -} - -void AddKernelInputSupportDataType(const kernel::KernelBuildInfo &kernel_build_info, size_t input_index, - std::vector *support_datatype_index, std::vector *support_datatype) { - MS_EXCEPTION_IF_NULL(support_datatype); - auto data_type = kernel_build_info.GetInputDeviceType(input_index); - support_datatype->push_back(data_type); - AddSupportMixedPrecisionDataTypeIndex(data_type, support_datatype_index); -} - -void AddKernelOutputSupportDataType(const kernel::KernelBuildInfo &kernel_build_info, size_t output_index, - std::vector *support_datatype_index, std::vector *support_datatype) { - MS_EXCEPTION_IF_NULL(support_datatype); - auto data_type = kernel_build_info.GetOutputDeviceType(output_index); - support_datatype->push_back(data_type); - AddSupportMixedPrecisionDataTypeIndex(data_type, support_datatype_index); -} - -void AddNodeInputDataType(const CNodePtr &kernel_node, size_t input_index, - std::vector *node_mix_precision_datatype_index, - std::vector *node_mix_precision_datatype) { - AnfNodePtr cur_input = AnfAlgo::GetInputNode(kernel_node, input_index); - MS_EXCEPTION_IF_NULL(cur_input); - MS_EXCEPTION_IF_NULL(node_mix_precision_datatype); - TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index); - AddSupportMixedPrecisionDataTypeIndex(input_origin_type, node_mix_precision_datatype_index); - node_mix_precision_datatype->push_back(input_origin_type); -} - -void AddNodeOutputDataType(const CNodePtr &kernel_node, size_t output_index, - std::vector *node_mix_precision_datatype_index, - std::vector *node_mix_precision_datatype) { - MS_EXCEPTION_IF_NULL(node_mix_precision_datatype); - auto output_origin_type = AnfAlgo::GetOutputInferDataType(kernel_node, output_index); - AddSupportMixedPrecisionDataTypeIndex(output_origin_type, node_mix_precision_datatype_index); - node_mix_precision_datatype->push_back(output_origin_type); -} - -void CheckDataTypeInputs(const std::vector &node_mix_precision_datatype_index, - const std::vector &node_mix_precision_datatype, - const std::map> &kernel_support_datatypes, - std::map> *kernel_match_datatype_idx) { - if (node_mix_precision_datatype_index.size() != node_mix_precision_datatype.size()) { - MS_LOG(EXCEPTION) << "Node datatype index size " << node_mix_precision_datatype_index.size() << " != datatype size " - << node_mix_precision_datatype.size(); - } - MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx); - if (kernel_support_datatypes.size() != kernel_match_datatype_idx->size()) { - MS_LOG(EXCEPTION) << "Kernel datatype index size " << kernel_match_datatype_idx->size() << " != datatype size " - << kernel_support_datatypes.size(); - } -} - -bool RaiseDataTypePrecisionSelect(const std::vector &node_mix_precision_datatype_index, - const std::vector &node_mix_precision_datatype, - const std::map> &kernel_support_datatypes, - std::map> *kernel_match_datatype_idx) { - MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx); - CheckDataTypeInputs(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatypes, - kernel_match_datatype_idx); - for (size_t i = 0; i < node_mix_precision_datatype_index.size(); ++i) { - if (node_mix_precision_datatype[i] == kTypeUnknown) { - continue; - } - auto iter = kernel_match_datatype_idx->begin(); - while (iter != kernel_match_datatype_idx->end()) { - if (node_mix_precision_datatype_index[i] == kUnSupportMixedDataTypeIndex) { - auto find_iter = kernel_support_datatypes.find(iter->first); - if (find_iter == kernel_support_datatypes.end()) { - MS_LOG(EXCEPTION) << "Kernel datatype index:%lu can not be found " << iter->first; - } - if (i >= find_iter->second.size()) { - MS_LOG(EXCEPTION) << "Node index " << i << "kernel datatype size " << find_iter->second.size(); - } - if (node_mix_precision_datatype[i] != find_iter->second[i]) { - iter = kernel_match_datatype_idx->erase(iter); - } else { - ++iter; - } - continue; - } - auto datatype_indexes = iter->second; - if (i >= datatype_indexes.size()) { - MS_LOG(EXCEPTION) << "Node datatype index: " << i << " kernel support size " << datatype_indexes.size(); - } - if (datatype_indexes[i] < node_mix_precision_datatype_index[i]) { - iter = kernel_match_datatype_idx->erase(iter); - } else { - ++iter; - } - } - } - return !kernel_match_datatype_idx->empty(); -} - -bool CanDataTypeReduce(const std::vector &datatype_indexes, int check_index, - const std::vector &node_mix_precision_datatype_index) { - auto check_index_tmp = IntToSize(check_index); - if (check_index_tmp < datatype_indexes.size() && check_index_tmp < node_mix_precision_datatype_index.size()) { - return datatype_indexes[check_index] != kUnSupportMixedDataTypeIndex && - datatype_indexes[check_index] <= node_mix_precision_datatype_index[check_index]; - } - MS_LOG(EXCEPTION) << "Check index " << check_index << "is outof range"; -} - -bool RaiseOrReduceDataTypePrecisionSelect(const std::vector &node_mix_precision_datatype_index, - const std::vector &node_mix_precision_datatype, - const std::map> &kernel_support_datatypes, - std::map> *kernel_match_datatype_idx) { - MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx); - CheckDataTypeInputs(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatypes, - kernel_match_datatype_idx); - for (size_t i = 0; i < node_mix_precision_datatype_index.size(); ++i) { - if (node_mix_precision_datatype[i] == kTypeUnknown) { - continue; - } - auto iter = kernel_match_datatype_idx->begin(); - while (iter != kernel_match_datatype_idx->end()) { - if (node_mix_precision_datatype_index[i] == kUnSupportMixedDataTypeIndex) { - auto find_iter = kernel_support_datatypes.find(iter->first); - if (find_iter == kernel_support_datatypes.end()) { - MS_LOG(EXCEPTION) << "Kernel datatype index:%lu can not be found " << iter->first; - } - if (i >= find_iter->second.size()) { - MS_LOG(EXCEPTION) << "Node index " << i << " >= kernel datatype size " << find_iter->second.size(); - } - if (node_mix_precision_datatype[i] != find_iter->second[i]) { - iter = kernel_match_datatype_idx->erase(iter); - } else { - ++iter; - } - continue; - } - auto datatype_indexes = iter->second; - if (i >= datatype_indexes.size()) { - MS_LOG(EXCEPTION) << "Index " << i << "> kernel datatype indexes size " << datatype_indexes.size(); - } - if (!CanDataTypeReduce(datatype_indexes, i, node_mix_precision_datatype_index)) { - iter = kernel_match_datatype_idx->erase(iter); - } else { - ++iter; - } - } - } - return !kernel_match_datatype_idx->empty(); -} - -void AddNodeAndKernelDataType(const CNodePtr &kernel_node, const kernel::KernelBuildInfo &kernel_build_info, - std::vector *support_indexes, std::vector *node_mix_precision_datatype, - std::vector *support_datatypes, - std::vector *node_mix_precision_datatype_index) { - MS_EXCEPTION_IF_NULL(node_mix_precision_datatype); - bool add_node_datatype_flag = false; - if (node_mix_precision_datatype->empty()) { - add_node_datatype_flag = true; - } - for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) { - AddKernelInputSupportDataType(kernel_build_info, input_index, support_indexes, support_datatypes); - if (add_node_datatype_flag) { - AddNodeInputDataType(kernel_node, input_index, node_mix_precision_datatype_index, node_mix_precision_datatype); - } - } - // Check output data type - for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) { - AddKernelOutputSupportDataType(kernel_build_info, output_index, support_indexes, support_datatypes); - if (add_node_datatype_flag) { - AddNodeOutputDataType(kernel_node, output_index, node_mix_precision_datatype_index, node_mix_precision_datatype); - } - } -} - -void PrecisionReduce(const std::vector &node_mix_precision_datatype_index, - const std::vector &node_mix_precision_datatype, - const std::map> &kernel_support_datatype, - std::map> *kernel_match_datatype_idx, bool *precision_reduce) { - MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx); - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - MS_EXCEPTION_IF_NULL(precision_reduce); - std::map> kernel_match_datatype_idx_copy = *kernel_match_datatype_idx; - // raise precision - bool selected_ret = RaiseDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype, - kernel_support_datatype, kernel_match_datatype_idx); - if (selected_ret) { - *precision_reduce = false; - return; - } - if (context_ptr->get_param(MS_CTX_ENABLE_REDUCE_PRECISION)) { - selected_ret = RaiseOrReduceDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype, - kernel_support_datatype, &kernel_match_datatype_idx_copy); - } - if (selected_ret) { - *precision_reduce = true; - *kernel_match_datatype_idx = kernel_match_datatype_idx_copy; - } -} - void PrintRaiseOrReducePrecisionSelectedInfo(const CNodePtr &cnode, const std::shared_ptr &selected_kernel_build_info, bool precision_reduce) { @@ -434,30 +222,82 @@ std::vector> FilteredKernelInfoByDtype( return result; } +bool TagRaiseReduce(const std::shared_ptr &kernel_build_info, const CNodePtr &cnode, + const std::map &type_map) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(kernel_build_info); + size_t flag_in = 0; + size_t flag_out = 0; + for (size_t input_index = 0; input_index < kernel_build_info->GetInputNum(); ++input_index) { + auto in_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index); + auto device_dtype = kernel_build_info->GetInputDeviceType(input_index); + if (device_dtype == kNumberTypeFloat || device_dtype == kNumberTypeFloat32) { + device_dtype = kNumberTypeFloat32; + } + auto iter = type_map.find(in_dtype); + if (iter != type_map.end() && iter->second != device_dtype && in_dtype != device_dtype) { + return false; + } + if (iter == type_map.end() && in_dtype != device_dtype) { + flag_in += 1; + } + } + + for (size_t output_index = 0; output_index < kernel_build_info->GetOutputNum(); ++output_index) { + auto in_dtype = AnfAlgo::GetOutputInferDataType(cnode, output_index); + auto device_dtype = kernel_build_info->GetOutputDeviceType(output_index); + if (device_dtype == kNumberTypeFloat || device_dtype == kNumberTypeFloat32) { + device_dtype = kNumberTypeFloat32; + } + auto iter = type_map.find(in_dtype); + if (iter != type_map.end() && iter->second != device_dtype && in_dtype != device_dtype) { + return false; + } + if (iter == type_map.end() && in_dtype != device_dtype) { + flag_out += 1; + } + } + if (flag_in == kernel_build_info->GetInputNum() || flag_out == kernel_build_info->GetOutputNum()) { + return false; + } + return true; +} + std::vector> FilterRaisedOrReducePrecisionMatchedKernelInfo( const CNodePtr &cnode, const std::vector> &kernel_info_list, bool *precision_reduce) { std::vector> filtered_kernel_info_list; - std::map> kernel_match_datatype_idx; - std::map> kernel_support_datatype; - std::vector node_mix_precision_datatype_index; - std::vector node_mix_precision_datatype; + const std::map raise_map = {{kNumberTypeFloat16, kNumberTypeFloat32}}; + const std::map reduce_map = {{kNumberTypeInt64, kNumberTypeInt32}, + {kNumberTypeFloat, kNumberTypeFloat16}, + {kNumberTypeFloat32, kNumberTypeFloat16}}; + // raise precision for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) { - std::vector support_indexes; - std::vector support_datatypes; MS_EXCEPTION_IF_NULL(kernel_info_list[info_index]); - AddNodeAndKernelDataType(cnode, *kernel_info_list[info_index], &support_indexes, &node_mix_precision_datatype, - &support_datatypes, &node_mix_precision_datatype_index); - kernel_match_datatype_idx[info_index] = support_indexes; - kernel_support_datatype[info_index] = support_datatypes; + if (TagRaiseReduce(kernel_info_list[info_index], cnode, raise_map)) { + filtered_kernel_info_list.push_back(kernel_info_list[info_index]); + } + } + + if (!filtered_kernel_info_list.empty()) { + *precision_reduce = false; + return filtered_kernel_info_list; + } + + // reduce precision + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (context_ptr->get_param(MS_CTX_ENABLE_REDUCE_PRECISION)) { + for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) { + MS_EXCEPTION_IF_NULL(kernel_info_list[info_index]); + if (TagRaiseReduce(kernel_info_list[info_index], cnode, reduce_map)) { + filtered_kernel_info_list.push_back(kernel_info_list[info_index]); + } + } + } + if (!filtered_kernel_info_list.empty()) { + *precision_reduce = true; } - PrecisionReduce(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatype, - &kernel_match_datatype_idx, precision_reduce); - std::transform( - kernel_match_datatype_idx.begin(), kernel_match_datatype_idx.end(), std::back_inserter(filtered_kernel_info_list), - [&](const std::pair> &matched_idx) -> std::shared_ptr { - return kernel_info_list[matched_idx.first]; - }); return filtered_kernel_info_list; } } // namespace