forked from mindspore-Ecosystem/mindspore
!7217 reduce or raise precision restructure
Merge pull request !7217 from liubuyu/op_support
This commit is contained in:
commit
84f66ef5b8
|
@ -16,23 +16,23 @@
|
|||
|
||||
#include "runtime/device/ascend/kernel_select_ascend.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include "utils/ms_utils.h"
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/kernel_build_info.h"
|
||||
#include "backend/kernel_compiler/kernel_query.h"
|
||||
#include "backend/kernel_compiler/oplib/oplib.h"
|
||||
#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/kernel_compiler/kernel_query.h"
|
||||
#include "backend/kernel_compiler/oplib/oplib.h"
|
||||
#include "backend/kernel_compiler/kernel_build_info.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
|
@ -172,218 +172,6 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
|
|||
}
|
||||
}
|
||||
|
||||
// Appends the mixed-precision index of `data_type` to `support_index`:
// 0 for float16, 1 for float/float32, kUnSupportMixedDataTypeIndex for any other type.
// Raises (via MS_EXCEPTION_IF_NULL) when `support_index` is null.
void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector<int> *support_index) {
  MS_EXCEPTION_IF_NULL(support_index);
  int mixed_index = kUnSupportMixedDataTypeIndex;
  if (data_type == kNumberTypeFloat16) {
    mixed_index = 0;
  } else if (data_type == kNumberTypeFloat32 || data_type == kNumberTypeFloat) {
    mixed_index = 1;
  }
  support_index->push_back(mixed_index);
}
|
||||
|
||||
// Records the device data type the kernel declares for input `input_index`: the type itself is appended to
// `support_datatype` and its mixed-precision index to `support_datatype_index`.
void AddKernelInputSupportDataType(const kernel::KernelBuildInfo &kernel_build_info, size_t input_index,
                                   std::vector<int> *support_datatype_index, std::vector<TypeId> *support_datatype) {
  MS_EXCEPTION_IF_NULL(support_datatype);
  const auto input_device_type = kernel_build_info.GetInputDeviceType(input_index);
  support_datatype->push_back(input_device_type);
  // The index vector is null-checked inside the helper.
  AddSupportMixedPrecisionDataTypeIndex(input_device_type, support_datatype_index);
}
|
||||
|
||||
// Records the device data type the kernel declares for output `output_index`: the type itself is appended to
// `support_datatype` and its mixed-precision index to `support_datatype_index`.
void AddKernelOutputSupportDataType(const kernel::KernelBuildInfo &kernel_build_info, size_t output_index,
                                    std::vector<int> *support_datatype_index, std::vector<TypeId> *support_datatype) {
  MS_EXCEPTION_IF_NULL(support_datatype);
  const auto output_device_type = kernel_build_info.GetOutputDeviceType(output_index);
  support_datatype->push_back(output_device_type);
  // The index vector is null-checked inside the helper.
  AddSupportMixedPrecisionDataTypeIndex(output_device_type, support_datatype_index);
}
|
||||
|
||||
void AddNodeInputDataType(const CNodePtr &kernel_node, size_t input_index,
|
||||
std::vector<int> *node_mix_precision_datatype_index,
|
||||
std::vector<TypeId> *node_mix_precision_datatype) {
|
||||
AnfNodePtr cur_input = AnfAlgo::GetInputNode(kernel_node, input_index);
|
||||
MS_EXCEPTION_IF_NULL(cur_input);
|
||||
MS_EXCEPTION_IF_NULL(node_mix_precision_datatype);
|
||||
TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index);
|
||||
AddSupportMixedPrecisionDataTypeIndex(input_origin_type, node_mix_precision_datatype_index);
|
||||
node_mix_precision_datatype->push_back(input_origin_type);
|
||||
}
|
||||
|
||||
void AddNodeOutputDataType(const CNodePtr &kernel_node, size_t output_index,
|
||||
std::vector<int> *node_mix_precision_datatype_index,
|
||||
std::vector<TypeId> *node_mix_precision_datatype) {
|
||||
MS_EXCEPTION_IF_NULL(node_mix_precision_datatype);
|
||||
auto output_origin_type = AnfAlgo::GetOutputInferDataType(kernel_node, output_index);
|
||||
AddSupportMixedPrecisionDataTypeIndex(output_origin_type, node_mix_precision_datatype_index);
|
||||
node_mix_precision_datatype->push_back(output_origin_type);
|
||||
}
|
||||
|
||||
void CheckDataTypeInputs(const std::vector<int> &node_mix_precision_datatype_index,
|
||||
const std::vector<TypeId> &node_mix_precision_datatype,
|
||||
const std::map<size_t, std::vector<TypeId>> &kernel_support_datatypes,
|
||||
std::map<size_t, std::vector<int>> *kernel_match_datatype_idx) {
|
||||
if (node_mix_precision_datatype_index.size() != node_mix_precision_datatype.size()) {
|
||||
MS_LOG(EXCEPTION) << "Node datatype index size " << node_mix_precision_datatype_index.size() << " != datatype size "
|
||||
<< node_mix_precision_datatype.size();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx);
|
||||
if (kernel_support_datatypes.size() != kernel_match_datatype_idx->size()) {
|
||||
MS_LOG(EXCEPTION) << "Kernel datatype index size " << kernel_match_datatype_idx->size() << " != datatype size "
|
||||
<< kernel_support_datatypes.size();
|
||||
}
|
||||
}
|
||||
|
||||
bool RaiseDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_datatype_index,
|
||||
const std::vector<TypeId> &node_mix_precision_datatype,
|
||||
const std::map<size_t, std::vector<TypeId>> &kernel_support_datatypes,
|
||||
std::map<size_t, std::vector<int>> *kernel_match_datatype_idx) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx);
|
||||
CheckDataTypeInputs(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatypes,
|
||||
kernel_match_datatype_idx);
|
||||
for (size_t i = 0; i < node_mix_precision_datatype_index.size(); ++i) {
|
||||
if (node_mix_precision_datatype[i] == kTypeUnknown) {
|
||||
continue;
|
||||
}
|
||||
auto iter = kernel_match_datatype_idx->begin();
|
||||
while (iter != kernel_match_datatype_idx->end()) {
|
||||
if (node_mix_precision_datatype_index[i] == kUnSupportMixedDataTypeIndex) {
|
||||
auto find_iter = kernel_support_datatypes.find(iter->first);
|
||||
if (find_iter == kernel_support_datatypes.end()) {
|
||||
MS_LOG(EXCEPTION) << "Kernel datatype index:%lu can not be found " << iter->first;
|
||||
}
|
||||
if (i >= find_iter->second.size()) {
|
||||
MS_LOG(EXCEPTION) << "Node index " << i << "kernel datatype size " << find_iter->second.size();
|
||||
}
|
||||
if (node_mix_precision_datatype[i] != find_iter->second[i]) {
|
||||
iter = kernel_match_datatype_idx->erase(iter);
|
||||
} else {
|
||||
++iter;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
auto datatype_indexes = iter->second;
|
||||
if (i >= datatype_indexes.size()) {
|
||||
MS_LOG(EXCEPTION) << "Node datatype index: " << i << " kernel support size " << datatype_indexes.size();
|
||||
}
|
||||
if (datatype_indexes[i] < node_mix_precision_datatype_index[i]) {
|
||||
iter = kernel_match_datatype_idx->erase(iter);
|
||||
} else {
|
||||
++iter;
|
||||
}
|
||||
}
|
||||
}
|
||||
return !kernel_match_datatype_idx->empty();
|
||||
}
|
||||
|
||||
bool CanDataTypeReduce(const std::vector<int> &datatype_indexes, int check_index,
|
||||
const std::vector<int> &node_mix_precision_datatype_index) {
|
||||
auto check_index_tmp = IntToSize(check_index);
|
||||
if (check_index_tmp < datatype_indexes.size() && check_index_tmp < node_mix_precision_datatype_index.size()) {
|
||||
return datatype_indexes[check_index] != kUnSupportMixedDataTypeIndex &&
|
||||
datatype_indexes[check_index] <= node_mix_precision_datatype_index[check_index];
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Check index " << check_index << "is outof range";
|
||||
}
|
||||
|
||||
bool RaiseOrReduceDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_datatype_index,
|
||||
const std::vector<TypeId> &node_mix_precision_datatype,
|
||||
const std::map<size_t, std::vector<TypeId>> &kernel_support_datatypes,
|
||||
std::map<size_t, std::vector<int>> *kernel_match_datatype_idx) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx);
|
||||
CheckDataTypeInputs(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatypes,
|
||||
kernel_match_datatype_idx);
|
||||
for (size_t i = 0; i < node_mix_precision_datatype_index.size(); ++i) {
|
||||
if (node_mix_precision_datatype[i] == kTypeUnknown) {
|
||||
continue;
|
||||
}
|
||||
auto iter = kernel_match_datatype_idx->begin();
|
||||
while (iter != kernel_match_datatype_idx->end()) {
|
||||
if (node_mix_precision_datatype_index[i] == kUnSupportMixedDataTypeIndex) {
|
||||
auto find_iter = kernel_support_datatypes.find(iter->first);
|
||||
if (find_iter == kernel_support_datatypes.end()) {
|
||||
MS_LOG(EXCEPTION) << "Kernel datatype index:%lu can not be found " << iter->first;
|
||||
}
|
||||
if (i >= find_iter->second.size()) {
|
||||
MS_LOG(EXCEPTION) << "Node index " << i << " >= kernel datatype size " << find_iter->second.size();
|
||||
}
|
||||
if (node_mix_precision_datatype[i] != find_iter->second[i]) {
|
||||
iter = kernel_match_datatype_idx->erase(iter);
|
||||
} else {
|
||||
++iter;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
auto datatype_indexes = iter->second;
|
||||
if (i >= datatype_indexes.size()) {
|
||||
MS_LOG(EXCEPTION) << "Index " << i << "> kernel datatype indexes size " << datatype_indexes.size();
|
||||
}
|
||||
if (!CanDataTypeReduce(datatype_indexes, i, node_mix_precision_datatype_index)) {
|
||||
iter = kernel_match_datatype_idx->erase(iter);
|
||||
} else {
|
||||
++iter;
|
||||
}
|
||||
}
|
||||
}
|
||||
return !kernel_match_datatype_idx->empty();
|
||||
}
|
||||
|
||||
// Collects, for one candidate kernel, the device type and mixed-precision index of every input followed by
// every output into `support_datatypes` / `support_indexes`.
// When `node_mix_precision_datatype` is empty on entry, the node's own inferred types and indexes are gathered
// in the same pass; if it already holds data, the node-side vectors are left untouched.
void AddNodeAndKernelDataType(const CNodePtr &kernel_node, const kernel::KernelBuildInfo &kernel_build_info,
                              std::vector<int> *support_indexes, std::vector<TypeId> *node_mix_precision_datatype,
                              std::vector<TypeId> *support_datatypes,
                              std::vector<int> *node_mix_precision_datatype_index) {
  MS_EXCEPTION_IF_NULL(node_mix_precision_datatype);
  // Decided once up front: the vector grows inside the loops, so it must not be re-tested there.
  const bool collect_node_types = node_mix_precision_datatype->empty();

  // Inputs first ...
  for (size_t in_idx = 0; in_idx < kernel_build_info.GetInputNum(); ++in_idx) {
    AddKernelInputSupportDataType(kernel_build_info, in_idx, support_indexes, support_datatypes);
    if (!collect_node_types) {
      continue;
    }
    AddNodeInputDataType(kernel_node, in_idx, node_mix_precision_datatype_index, node_mix_precision_datatype);
  }

  // ... then outputs, appended to the same vectors.
  for (size_t out_idx = 0; out_idx < kernel_build_info.GetOutputNum(); ++out_idx) {
    AddKernelOutputSupportDataType(kernel_build_info, out_idx, support_indexes, support_datatypes);
    if (!collect_node_types) {
      continue;
    }
    AddNodeOutputDataType(kernel_node, out_idx, node_mix_precision_datatype_index, node_mix_precision_datatype);
  }
}
|
||||
|
||||
void PrecisionReduce(const std::vector<int> &node_mix_precision_datatype_index,
|
||||
const std::vector<TypeId> &node_mix_precision_datatype,
|
||||
const std::map<size_t, std::vector<TypeId>> &kernel_support_datatype,
|
||||
std::map<size_t, std::vector<int>> *kernel_match_datatype_idx, bool *precision_reduce) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx);
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
MS_EXCEPTION_IF_NULL(precision_reduce);
|
||||
std::map<size_t, std::vector<int>> kernel_match_datatype_idx_copy = *kernel_match_datatype_idx;
|
||||
// raise precision
|
||||
bool selected_ret = RaiseDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
|
||||
kernel_support_datatype, kernel_match_datatype_idx);
|
||||
if (selected_ret) {
|
||||
*precision_reduce = false;
|
||||
return;
|
||||
}
|
||||
if (context_ptr->get_param<bool>(MS_CTX_ENABLE_REDUCE_PRECISION)) {
|
||||
selected_ret = RaiseOrReduceDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
|
||||
kernel_support_datatype, &kernel_match_datatype_idx_copy);
|
||||
}
|
||||
if (selected_ret) {
|
||||
*precision_reduce = true;
|
||||
*kernel_match_datatype_idx = kernel_match_datatype_idx_copy;
|
||||
}
|
||||
}
|
||||
|
||||
void PrintRaiseOrReducePrecisionSelectedInfo(const CNodePtr &cnode,
|
||||
const std::shared_ptr<kernel::KernelBuildInfo> &selected_kernel_build_info,
|
||||
bool precision_reduce) {
|
||||
|
@ -434,30 +222,82 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilteredKernelInfoByDtype(
|
|||
return result;
|
||||
}
|
||||
|
||||
// Decides whether `kernel_build_info` is acceptable for `cnode` under the dtype conversions in `type_map`
// (inferred dtype -> permitted device dtype).
//
// Each input and output slot is examined after folding a kNumberTypeFloat device dtype into kNumberTypeFloat32:
//   - a slot whose device dtype equals the inferred dtype is always fine;
//   - a slot whose inferred dtype has an entry in `type_map` must use exactly the mapped device dtype,
//     otherwise the kernel is rejected immediately;
//   - a slot whose inferred dtype has no entry and whose device dtype differs is tolerated but counted.
// The kernel is finally rejected when the counted slots make up all inputs or all outputs.
//
// NOTE(review): when GetInputNum() (or GetOutputNum()) is 0 the corresponding count is also 0, so the final
// comparison rejects the kernel - confirm that this is intended for kernels without inputs.
bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info, const CNodePtr &cnode,
                    const std::map<TypeId, TypeId> &type_map) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(kernel_build_info);

  // Shared per-slot rule; returns false only for a hard reject, bumps *unmapped_mismatch for tolerated mismatches.
  auto slot_acceptable = [&type_map](TypeId infer_dtype, TypeId device_dtype, size_t *unmapped_mismatch) -> bool {
    if (device_dtype == kNumberTypeFloat) {
      device_dtype = kNumberTypeFloat32;
    }
    if (infer_dtype == device_dtype) {
      return true;
    }
    auto mapped = type_map.find(infer_dtype);
    if (mapped == type_map.end()) {
      ++(*unmapped_mismatch);
      return true;
    }
    return mapped->second == device_dtype;
  };

  size_t unmapped_inputs = 0;
  for (size_t idx = 0; idx < kernel_build_info->GetInputNum(); ++idx) {
    const TypeId infer_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, idx);
    const TypeId device_dtype = kernel_build_info->GetInputDeviceType(idx);
    if (!slot_acceptable(infer_dtype, device_dtype, &unmapped_inputs)) {
      return false;
    }
  }

  size_t unmapped_outputs = 0;
  for (size_t idx = 0; idx < kernel_build_info->GetOutputNum(); ++idx) {
    const TypeId infer_dtype = AnfAlgo::GetOutputInferDataType(cnode, idx);
    const TypeId device_dtype = kernel_build_info->GetOutputDeviceType(idx);
    if (!slot_acceptable(infer_dtype, device_dtype, &unmapped_outputs)) {
      return false;
    }
  }

  return unmapped_inputs != kernel_build_info->GetInputNum() && unmapped_outputs != kernel_build_info->GetOutputNum();
}
|
||||
|
||||
// Selects, from `kernel_info_list`, the kernels that can serve `cnode` by raising or (if enabled) reducing
// precision, and reports through `*precision_reduce` which of the two was needed.
//
// Steps, in order:
//   1. For every candidate, gather its dtype/index tables and keep it when TagRaiseReduce accepts it under the
//      raise map (float16 -> float32). If anything was kept, `*precision_reduce` is set to false and the list is
//      returned immediately.
//   2. When MS_CTX_ENABLE_REDUCE_PRECISION is set, keep every candidate TagRaiseReduce accepts under the reduce
//      map (int64 -> int32, float/float32 -> float16); if anything was kept, `*precision_reduce` is set to true.
//   3. Run PrecisionReduce on the tables gathered in step 1 and append every kernel it leaves in the match map.
//
// Raises when `precision_reduce` or any element of `kernel_info_list` is null.
//
// NOTE(review): step 3 appends to the list step 2 may already have filled, without clearing or de-duplicating,
// so the same kernel can appear twice in the result - confirm that both selection paths are meant to coexist.
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecisionMatchedKernelInfo(
  const CNodePtr &cnode, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list,
  bool *precision_reduce) {
  // The out-parameter is written on every successful path below, so validate it before doing any work.
  MS_EXCEPTION_IF_NULL(precision_reduce);
  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_kernel_info_list;
  std::map<size_t, std::vector<int>> kernel_match_datatype_idx;
  std::map<size_t, std::vector<TypeId>> kernel_support_datatype;
  std::vector<int> node_mix_precision_datatype_index;
  std::vector<TypeId> node_mix_precision_datatype;
  const std::map<TypeId, TypeId> raise_map = {{kNumberTypeFloat16, kNumberTypeFloat32}};
  const std::map<TypeId, TypeId> reduce_map = {{kNumberTypeInt64, kNumberTypeInt32},
                                               {kNumberTypeFloat, kNumberTypeFloat16},
                                               {kNumberTypeFloat32, kNumberTypeFloat16}};
  // raise precision
  for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
    std::vector<int> support_indexes;
    std::vector<TypeId> support_datatypes;
    MS_EXCEPTION_IF_NULL(kernel_info_list[info_index]);
    AddNodeAndKernelDataType(cnode, *kernel_info_list[info_index], &support_indexes, &node_mix_precision_datatype,
                             &support_datatypes, &node_mix_precision_datatype_index);
    kernel_match_datatype_idx[info_index] = support_indexes;
    kernel_support_datatype[info_index] = support_datatypes;
    if (TagRaiseReduce(kernel_info_list[info_index], cnode, raise_map)) {
      filtered_kernel_info_list.push_back(kernel_info_list[info_index]);
    }
  }

  if (!filtered_kernel_info_list.empty()) {
    *precision_reduce = false;
    return filtered_kernel_info_list;
  }

  // reduce precision
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (context_ptr->get_param<bool>(MS_CTX_ENABLE_REDUCE_PRECISION)) {
    for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
      MS_EXCEPTION_IF_NULL(kernel_info_list[info_index]);
      if (TagRaiseReduce(kernel_info_list[info_index], cnode, reduce_map)) {
        filtered_kernel_info_list.push_back(kernel_info_list[info_index]);
      }
    }
  }
  if (!filtered_kernel_info_list.empty()) {
    *precision_reduce = true;
  }
  PrecisionReduce(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatype,
                  &kernel_match_datatype_idx, precision_reduce);
  // Iterate by const reference to the map's real element type; only the key is needed, so nothing is copied.
  for (const auto &matched_idx : kernel_match_datatype_idx) {
    filtered_kernel_info_list.push_back(kernel_info_list[matched_idx.first]);
  }
  return filtered_kernel_info_list;
}
|
||||
} // namespace
|
||||
|
|
Loading…
Reference in New Issue