diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc index ddb01bde931..66c77dfec3f 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc @@ -30,6 +30,10 @@ namespace { void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) { MS_EXCEPTION_IF_NULL(cnode_ptr); MS_EXCEPTION_IF_NULL(graph); + if (AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimCall) || + AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimPartial)) { + return; + } std::vector plant_inputs; std::vector dyn_input_sizes; plant_inputs.push_back(AnfAlgo::GetCNodePrimitiveNode(cnode_ptr)); diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc index f4a57b0baac..207407436be 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc @@ -26,22 +26,19 @@ namespace mindspore { namespace opt { namespace { -AnfNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const AnfNodePtr &tuple_anf, - std::unordered_map *transed_nodes) { +AnfNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const AnfNodePtr &tuple_anf) { MS_EXCEPTION_IF_NULL(tuple_anf); MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(transed_nodes); if (!AnfAlgo::IsTupleOutput(tuple_anf)) { return tuple_anf; } - auto transed_node_it = transed_nodes->find(tuple_anf); - if (transed_node_it != transed_nodes->end()) { - return transed_node_it->second; - } auto kernel_graph = graph->cast(); + if (kernel_graph->FindTupleParameterToMakeTupleMap(tuple_anf)) { + return kernel_graph->FindTupleParameterToMakeTupleMap(tuple_anf); + } auto make_tuple = kernel_graph->TransTupleToMakeTuple(tuple_anf); - (*transed_nodes)[tuple_anf] = make_tuple; + kernel_graph->InsertTupleParameterToMakeTupleMap(tuple_anf, make_tuple); // replace graph inputs if input is a parameter kernel_graph->ReplaceGraphInput(tuple_anf, make_tuple); return make_tuple; @@ -61,7 +58,6 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func } auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); - std::unordered_map transed_nodes; if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) { auto real_input = AnfAlgo::GetTupleGetItemRealInput(cnode); MS_EXCEPTION_IF_NULL(real_input); @@ -77,7 +73,7 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func const auto &input = cnode->inputs()[i]; if (input->Type() != nullptr && AnfAlgo::IsRealKernel(input) && AnfAlgo::IsTupleOutput(input) && !AnfAlgo::CheckPrimitiveType(input, prim::kPrimCall)) { - cnode->set_input(i, ConvertTupleInputToMakeTuple(func_graph, input, &transed_nodes)); + cnode->set_input(i, ConvertTupleInputToMakeTuple(func_graph, input)); cnode_input_changed = true; } } diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.cc b/mindspore/ccsrc/backend/session/ascend_control_parser.cc index 08b23f3d673..277b4a6ffb8 100644 --- a/mindspore/ccsrc/backend/session/ascend_control_parser.cc +++ b/mindspore/ccsrc/backend/session/ascend_control_parser.cc @@ -523,12 +523,22 @@ void AscendControlParser::LinkParentGraph(NotNull kg, const CNod } } +void AscendControlParser::AttachOriginalInputsToGraph(NotNull graph, + const std::vector orig_inputs) { + std::vector make_tuple_inputs = { + mindspore::NewValueNode(std::make_shared(prim::kPrimMakeTuple->name()))}; + std::copy(orig_inputs.begin(), orig_inputs.end(), std::back_inserter(make_tuple_inputs)); + auto make_tuple = graph->NewCNode(make_tuple_inputs); + + InsertDependToGraph(graph, NOT_NULL(make_tuple)); +} + void AscendControlParser::RecurseCall(NotNull kg, NotNull cur_node, const CNodePtr &next_node, const NotNull *> memo) { MS_LOG(INFO) << "Process call func " << cur_node->DebugString(); // 1 get kernel graph - const std::vector &origin_inputs = cur_node->inputs(); + std::vector origin_inputs = cur_node->inputs(); if (kCNodeCallArg >= origin_inputs.size()) { MS_LOG(EXCEPTION) << "Index out of range,size:" << origin_inputs.size(); } @@ -555,6 +565,8 @@ void AscendControlParser::RecurseCall(NotNull kg, NotNullset_inputs(new_inputs); cur_node->set_abstract(nullptr); AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue>({call_kg}), cur_node.get()); + origin_inputs.assign(origin_inputs.begin() + kCNodeCallArg + 1, origin_inputs.end()); + AttachOriginalInputsToGraph(kg, origin_inputs); MS_LOG(INFO) << "Succeed processing call func " << cur_node->DebugString(); } @@ -587,11 +599,13 @@ void AscendControlParser::RecurseSwitch(NotNull kg, NotNull origin_inputs; + std::tie(branch_fg, origin_inputs) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); child_graphs.push_back(branch_fg); // 3.2 recurse sub graph CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); new_switch_inputs.push_back(branch_label); + AttachOriginalInputsToGraph(kg, origin_inputs); } std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]); @@ -635,11 +649,13 @@ void AscendControlParser::RecurseSwitchLayer(NotNull kg, NotNull for (size_t i = 0; i < branch_partial.size(); ++i) { // 3.1 branch kernel graph and args KernelGraphPtr branch_fg; - std::tie(branch_fg, std::ignore) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); + std::vector origin_inputs; + std::tie(branch_fg, origin_inputs) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); child_graphs.push_back(branch_fg); // 3.2 recurse sub graph CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); new_switch_inputs.push_back(branch_label); + AttachOriginalInputsToGraph(kg, origin_inputs); } new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end()); cur_node->set_inputs(new_switch_inputs); diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.h b/mindspore/ccsrc/backend/session/ascend_control_parser.h index 555de416229..d7074b6bbb3 100644 --- a/mindspore/ccsrc/backend/session/ascend_control_parser.h +++ b/mindspore/ccsrc/backend/session/ascend_control_parser.h @@ -76,6 +76,7 @@ class AscendControlParser { static bool CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cnode); static std::vector RecurseGraph(NotNull graph, const NotNull *> memo); + static void AttachOriginalInputsToGraph(NotNull graph, const std::vector orig_inputs); }; class AscendControlParser::ReferenceCounter { public: diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h index c2cce88b70e..46e38b4ac6e 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.h +++ b/mindspore/ccsrc/backend/session/kernel_graph.h @@ -162,6 +162,19 @@ class KernelGraph : public FuncGraph { void set_child_graph_result(const std::vector &child_graph_result) { child_graph_result_ = child_graph_result; } + void InsertTupleParameterToMakeTupleMap(const AnfNodePtr ¶m, const AnfNodePtr &make_tuple) { + if (tuple_parameter_to_make_tuple_map_.find(param) != tuple_parameter_to_make_tuple_map_.end()) { + return; + } + tuple_parameter_to_make_tuple_map_[param] = make_tuple; + } + AnfNodePtr FindTupleParameterToMakeTupleMap(const AnfNodePtr ¶m) { + if (tuple_parameter_to_make_tuple_map_.find(param) != tuple_parameter_to_make_tuple_map_.end()) { + return tuple_parameter_to_make_tuple_map_[param]; + } else { + return nullptr; + } + } private: // remove value node form graph @@ -229,6 +242,7 @@ class KernelGraph : public FuncGraph { std::unordered_map>> internal_outputs_to_front_map_; std::unordered_map> internal_outputs_tensor_map_; uint32_t current_epoch_; + std::unordered_map tuple_parameter_to_make_tuple_map_; }; } // namespace session using KernelGraphPtr = std::shared_ptr;