From 1f809f50e5c626677dcc4c730914ce2e79d68490 Mon Sep 17 00:00:00 2001 From: chujinjin Date: Sat, 18 Jul 2020 15:15:04 +0800 Subject: [PATCH] fix precision error with fp16 input on PyNative mode --- mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc | 7 ++++++- mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h | 3 ++- .../ascend/format_type/deal_ref_trans_and_cast.cc | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc index 9e1f6234b9..c10f8ebecc 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc @@ -167,7 +167,8 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const } } // namespace void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, - const AnfNodePtr &trans_data, const std::vector &reshape_type) { + const AnfNodePtr &trans_data, const std::vector &reshape_type, + const TypeId &type_id) { MS_EXCEPTION_IF_NULL(trans_data); auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data); MS_EXCEPTION_IF_NULL(ori_build_info); @@ -176,6 +177,10 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string & builder->SetInputReshapeType({reshape_type}); builder->SetOutputReshapeType({reshape_type}); builder->SetOutputsFormat({output_format}); + if (type_id != kTypeUnknown) { + builder->SetOutputsDeviceType({type_id}); + builder->SetInputsDeviceType({type_id}); + } AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), trans_data.get()); } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h index cb308a09a0..31b68b16d5 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h @@ -86,7 +86,8 @@ class OpFinder { using OpFinderPtr = std::shared_ptr; void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, - const AnfNodePtr &trans_data, const std::vector &reshape_type = {}); + const AnfNodePtr &trans_data, const std::vector &reshape_type = {}, + const TypeId &type_id = kTypeUnknown); CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, const bool need_padding, const std::string &op_name); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc index 3dbe2d9f8a..886fea41b6 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc @@ -107,7 +107,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP if (origin_format != cur_format && cur_shape.size() > 1) { auto kernel_select = std::make_shared(); final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name()); - RefreshKernelBuildInfo(cur_format, origin_format, final_node); + RefreshKernelBuildInfo(cur_format, origin_format, final_node, {}, cur_type); final_index = 0; MS_EXCEPTION_IF_NULL(final_node); MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString();