!3196 fix precision error with fp16 input on pynative mode

Merge pull request !3196 from chujinjin/fix_precision_error_with_fp16_input
This commit is contained in:
mindspore-ci-bot 2020-07-20 09:09:28 +08:00 committed by Gitee
commit ab53809f2c
3 changed files with 9 additions and 3 deletions

View File

@ -167,7 +167,8 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
} }
} // namespace } // namespace
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type) { const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type,
const TypeId &type_id) {
MS_EXCEPTION_IF_NULL(trans_data); MS_EXCEPTION_IF_NULL(trans_data);
auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data); auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data);
MS_EXCEPTION_IF_NULL(ori_build_info); MS_EXCEPTION_IF_NULL(ori_build_info);
@ -176,6 +177,10 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string &
builder->SetInputReshapeType({reshape_type}); builder->SetInputReshapeType({reshape_type});
builder->SetOutputReshapeType({reshape_type}); builder->SetOutputReshapeType({reshape_type});
builder->SetOutputsFormat({output_format}); builder->SetOutputsFormat({output_format});
if (type_id != kTypeUnknown) {
builder->SetOutputsDeviceType({type_id});
builder->SetInputsDeviceType({type_id});
}
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get()); AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get());
} }

View File

@ -86,7 +86,8 @@ class OpFinder {
using OpFinderPtr = std::shared_ptr<OpFinder>; using OpFinderPtr = std::shared_ptr<OpFinder>;
void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type = {}); const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type = {},
const TypeId &type_id = kTypeUnknown);
CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
const bool need_padding, const std::string &op_name); const bool need_padding, const std::string &op_name);

View File

@ -107,7 +107,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP
if (origin_format != cur_format && cur_shape.size() > 1) { if (origin_format != cur_format && cur_shape.size() > 1) {
auto kernel_select = std::make_shared<KernelSelect>(); auto kernel_select = std::make_shared<KernelSelect>();
final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name()); final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name());
RefreshKernelBuildInfo(cur_format, origin_format, final_node); RefreshKernelBuildInfo(cur_format, origin_format, final_node, {}, cur_type);
final_index = 0; final_index = 0;
MS_EXCEPTION_IF_NULL(final_node); MS_EXCEPTION_IF_NULL(final_node);
MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString(); MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString();