forked from mindspore-Ecosystem/mindspore
!7637 raise reduce precision bug fix
Merge pull request !7637 from liubuyu/op_support
This commit is contained in:
commit
07c8a6114e
|
@ -224,10 +224,10 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilteredKernelInfoByDtype(
|
|||
|
||||
bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info, const CNodePtr &cnode,
|
||||
const std::map<TypeId, TypeId> &type_map) {
|
||||
// filte kernel info that unsupported raise or reduce datatype
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(kernel_build_info);
|
||||
size_t flag_in = 0;
|
||||
size_t flag_out = 0;
|
||||
bool flag = false;
|
||||
for (size_t input_index = 0; input_index < kernel_build_info->GetInputNum(); ++input_index) {
|
||||
auto in_dtype = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
|
||||
auto device_dtype = kernel_build_info->GetInputDeviceType(input_index);
|
||||
|
@ -235,11 +235,17 @@ bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build
|
|||
device_dtype = kNumberTypeFloat32;
|
||||
}
|
||||
auto iter = type_map.find(in_dtype);
|
||||
// if infer dtype node in type_map and the infer dtype not equal kernel info dtype, return false
|
||||
if (iter == type_map.end() && in_dtype != device_dtype) {
|
||||
return false;
|
||||
}
|
||||
// infer dtype in type_map, but can not find dst dtype that supported raise or reduce,
|
||||
// or infer dtype not equal kernel info dtype, return false
|
||||
if (iter != type_map.end() && iter->second != device_dtype && in_dtype != device_dtype) {
|
||||
return false;
|
||||
}
|
||||
if (iter == type_map.end() && in_dtype != device_dtype) {
|
||||
flag_in += 1;
|
||||
if (in_dtype == kNumberTypeInt64 && device_dtype == kNumberTypeInt32) {
|
||||
flag = true;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -250,15 +256,22 @@ bool TagRaiseReduce(const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build
|
|||
device_dtype = kNumberTypeFloat32;
|
||||
}
|
||||
auto iter = type_map.find(in_dtype);
|
||||
// if infer dtype node in type_map and the infer dtype not equal kernel info dtype, return false
|
||||
if (iter == type_map.end() && in_dtype != device_dtype) {
|
||||
return false;
|
||||
}
|
||||
// infer dtype in type_map, but can not find dst dtype that supported raise or reduce,
|
||||
// or infer dtype not equal kernel info dtype, return false
|
||||
if (iter != type_map.end() && iter->second != device_dtype && in_dtype != device_dtype) {
|
||||
return false;
|
||||
}
|
||||
if (iter == type_map.end() && in_dtype != device_dtype) {
|
||||
flag_out += 1;
|
||||
if (in_dtype == kNumberTypeInt64 && device_dtype == kNumberTypeInt32) {
|
||||
flag = true;
|
||||
}
|
||||
}
|
||||
if (flag_in == kernel_build_info->GetInputNum() || flag_out == kernel_build_info->GetOutputNum()) {
|
||||
return false;
|
||||
if (flag) {
|
||||
auto node_name = AnfAlgo::GetCNodeName(cnode);
|
||||
MS_LOG(WARNING) << "node:[" << node_name << "]reduce precision from int64 to int32";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue