!7637 raise reduce precision bug fix

Merge pull request !7637 from liubuyu/op_support
This commit is contained in:
mindspore-ci-bot 2020-10-23 13:51:38 +08:00 committed by Gitee
commit 07c8a6114e
1 changed files with 21 additions and 8 deletions

View File

@ -224,10 +224,10 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilteredKernelInfoByDtype(
bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info, const CNodePtr &cnode,
const std::map<TypeId, TypeId> &type_map) {
// filte kernel info that unsupported raise or reduce datatype
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(kernel_build_info);
size_t flag_in = 0;
size_t flag_out = 0;
bool flag = false;
for (size_t input_index = 0; input_index < kernel_build_info->GetInputNum(); ++input_index) {
auto in_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
auto device_dtype = kernel_build_info->GetInputDeviceType(input_index);
@ -235,11 +235,17 @@ bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build
device_dtype = kNumberTypeFloat32;
}
auto iter = type_map.find(in_dtype);
// if infer dtype node in type_map and the infer dtype not equal kernel info dtype, return false
if (iter == type_map.end() && in_dtype != device_dtype) {
return false;
}
// infer dtype in type_map, but can not find dst dtype that supported raise or reduce,
// or infer dtype not equal kernel info dtype, return false
if (iter != type_map.end() && iter->second != device_dtype && in_dtype != device_dtype) {
return false;
}
if (iter == type_map.end() && in_dtype != device_dtype) {
flag_in += 1;
if (in_dtype == kNumberTypeInt64 && device_dtype == kNumberTypeInt32) {
flag = true;
}
}
@ -250,15 +256,22 @@ bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build
device_dtype = kNumberTypeFloat32;
}
auto iter = type_map.find(in_dtype);
// if infer dtype node in type_map and the infer dtype not equal kernel info dtype, return false
if (iter == type_map.end() && in_dtype != device_dtype) {
return false;
}
// infer dtype in type_map, but can not find dst dtype that supported raise or reduce,
// or infer dtype not equal kernel info dtype, return false
if (iter != type_map.end() && iter->second != device_dtype && in_dtype != device_dtype) {
return false;
}
if (iter == type_map.end() && in_dtype != device_dtype) {
flag_out += 1;
if (in_dtype == kNumberTypeInt64 && device_dtype == kNumberTypeInt32) {
flag = true;
}
}
if (flag_in == kernel_build_info->GetInputNum() || flag_out == kernel_build_info->GetOutputNum()) {
return false;
if (flag) {
auto node_name = AnfAlgo::GetCNodeName(cnode);
MS_LOG(WARNING) << "node:[" << node_name << "]reduce precision from int64 to int32";
}
return true;
}