From e1a3c39face90f2f9ea23800464d6a7a886d295c Mon Sep 17 00:00:00 2001 From: lvliang Date: Mon, 10 Aug 2020 09:07:32 +0800 Subject: [PATCH] fix-bug-avoid-multi-attr-value-be-eliminated-in-pynative-mode --- .../optimizer/pass/convert_tuple_output_to_maketuple.cc | 9 ++++++++- .../frontend/optimizer/irpass/arithmetic_simplify.cc | 3 +++ mindspore/ccsrc/pipeline/pynative/pynative_execute.cc | 5 +++-- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc index 8bdc234e81f..f4a57b0baac 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc @@ -62,7 +62,14 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); std::unordered_map transed_nodes; - if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) || IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { + if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) { + auto real_input = AnfAlgo::GetTupleGetItemRealInput(cnode); + MS_EXCEPTION_IF_NULL(real_input); + if (!real_input->isa() && !real_input->isa()) { + return nullptr; + } + } + if (IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { return nullptr; } bool cnode_input_changed = false; diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc index ecdc44d25a3..3a0e6f6e14b 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc @@ -41,6 +41,9 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr } // Prim Eliminate (identity) MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x); + if (MsContext::GetInstance()->execution_mode() == kPynativeMode) { + return nullptr; + } // ConstantDuplicateMul auto const_dup_lambda = [&node, &x, &const_, &const_2]() -> AnfNodePtr { diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index b918de09424..b83306a922c 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -393,9 +393,7 @@ bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_i ValuePtr value = parse::data_converter::PyDataToValue(input_object); MS_EXCEPTION_IF_NULL(value); auto input_name = input_names_vec[input_index]; - op_prim->BeginRecordAddAttr(); op_prim->AddAttr(input_name, value); - op_prim->EndRecordAddAttr(); return true; } return false; @@ -499,6 +497,8 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector *te opt::ConstInputToAttrInfoRegister reg; bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, ®); + + op_prim->BeginRecordAddAttr(); size_t input_num = op_run_info->op_inputs.size(); for (size_t index = 0; index < input_num; ++index) { // convert const input to attr @@ -513,6 +513,7 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector *te std::vector new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask); tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end()); } + op_prim->EndRecordAddAttr(); } void EraseValueNodeTensor(const std::vector &tensors_mask, std::vector *input_tensors) {