From f868a2855f1e283150b720bd3e0b1a7ff3827de4 Mon Sep 17 00:00:00 2001 From: zhoufeng Date: Tue, 26 May 2020 09:17:38 +0800 Subject: [PATCH] Insert assign nodes for linking sub graph Signed-off-by: zhoufeng --- .../device/ascend/ascend_label_assign.cc | 19 +- mindspore/ccsrc/device/kernel_runtime.h | 2 +- .../ccsrc/session/ascend_control_parser.cc | 374 +++++++++++++----- .../ccsrc/session/ascend_control_parser.h | 34 +- mindspore/ccsrc/session/ascend_session.cc | 132 +++++-- mindspore/ccsrc/session/ascend_session.h | 8 +- mindspore/ccsrc/session/kernel_graph.cc | 154 ++++---- mindspore/ccsrc/session/kernel_graph.h | 11 +- mindspore/ccsrc/session/session_basic.cc | 77 ++-- mindspore/ccsrc/vm/transform.cc | 3 +- 10 files changed, 546 insertions(+), 268 deletions(-) diff --git a/mindspore/ccsrc/device/ascend/ascend_label_assign.cc b/mindspore/ccsrc/device/ascend/ascend_label_assign.cc index db68516500b..2973e5529cf 100644 --- a/mindspore/ccsrc/device/ascend/ascend_label_assign.cc +++ b/mindspore/ccsrc/device/ascend/ascend_label_assign.cc @@ -28,6 +28,9 @@ namespace device { namespace ascend { static void UpdateLabelGoto(NotNull node) { + if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, node)) { + return; + } if (node->size() <= kLabelGotoLabelId) { MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size(); } @@ -42,6 +45,9 @@ static void UpdateLabelGoto(NotNull node) { } static void UpdateLabelSwitch(NotNull node) { + if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, node)) { + return; + } if (node->size() <= kLabelGotoLabelId) { MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size(); } @@ -69,9 +75,12 @@ static void AssignLabelForLabelSet(NotNull if (memo->find(graph.get()) != memo->end()) { return; } + memo->insert(graph.get()); MS_LOG(INFO) << "Assign label for " << graph->ToString(); - auto nodes = TopoSort(graph->get_return()); + graph->SetExecOrderByDefault(); + auto nodes = graph->execution_order(); + for (auto &node : nodes) { if (!node->isa()) { continue; @@ -97,9 +106,15 @@ static void AssignLabelForGotoSwitch(NotNullfind(graph.get()) != memo->end()) { return; } + memo->insert(graph.get()); MS_LOG(INFO) << "Process label goto/switch for " << graph->ToString(); - auto nodes = TopoSort(graph->get_return()); + graph->SetExecOrderByDefault(); + auto nodes = graph->execution_order(); + auto end_goto = graph->get_end_goto(); + if (end_goto != nullptr) { + nodes.push_back(end_goto); + } for (auto &node : nodes) { if (!node->isa()) { continue; diff --git a/mindspore/ccsrc/device/kernel_runtime.h b/mindspore/ccsrc/device/kernel_runtime.h index bf44698b896..02f87671e9d 100644 --- a/mindspore/ccsrc/device/kernel_runtime.h +++ b/mindspore/ccsrc/device/kernel_runtime.h @@ -53,6 +53,7 @@ class KernelRuntime { virtual bool GenTask(const session::KernelGraph *graph); bool LaunchKernel(const session::KernelGraph *graph); virtual void AssignStaticMemoryInput(const session::KernelGraph *graph); + virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph); #ifdef ENABLE_DUMP_E2E DumpConfPtr GetDumpConf(); @@ -67,7 +68,6 @@ class KernelRuntime { TypeId type_id) = 0; virtual bool SyncStream() = 0; void AssignStaticMemory(session::KernelGraph *graph); - void AssignStaticMemoryValueNode(session::KernelGraph *graph); void AssignDynamicMemory(session::KernelGraph *graph); void ReuseAssignDynamicMemory(session::KernelGraph *graph); void AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index); diff --git a/mindspore/ccsrc/session/ascend_control_parser.cc b/mindspore/ccsrc/session/ascend_control_parser.cc index 66587678144..416ea49e634 100644 --- a/mindspore/ccsrc/session/ascend_control_parser.cc +++ b/mindspore/ccsrc/session/ascend_control_parser.cc @@ -22,49 +22,78 @@ namespace mindspore { namespace session { -static VectorRef GetCallArgs(std::vector::iterator iter_begin, std::vector::iterator iter_end) { - VectorRef call_args; - for (auto iter = iter_begin; iter != iter_end; ++iter) { - if (utils::isa(*iter)) { - call_args.push_back(GetValueNode(*iter)); - } else { - call_args.push_back(*iter); +void AscendControlParser::ChildGraphDataAssign(const std::map &graph_id_map) { + for (auto &iter : graph_id_map) { + auto &kg = iter.second; + MS_EXCEPTION_IF_NULL(kg); + auto real_inputs = kg->real_inputs(); + for (auto &it : real_inputs) { + auto ¶meter = it.first; + auto &args = it.second; + for (auto &arg : args) { + MS_EXCEPTION_IF_NULL(arg); + if (arg->isa()) { + MS_LOG(INFO) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString() + << ", arg:" << arg->DebugString(); + continue; + } + auto target_graph_iter = graph_id_map.find(AnfAlgo::GetGraphId(arg.get())); + if (target_graph_iter == graph_id_map.end()) { + MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found."; + } + InsertAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(arg), NOT_NULL(parameter)); + } } } - return call_args; } void AscendControlParser::LinkGraph(NotNull kg) { std::set memo; - ProcessKernelGraph(kg, nullptr, nullptr, {}, NOT_NULL(&memo)); + ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo)); + std::map graph_id_map; + for (auto &g : memo) { + if (graph_id_map.find(g->graph_id()) != graph_id_map.end()) { + MS_LOG(EXCEPTION) << "Two graph has same graph id " << g->graph_id() + << ", graph: " << graph_id_map[g->graph_id()]->ToString() << " " << g->ToString(); + } + graph_id_map[g->graph_id()] = g; + } + ChildGraphDataAssign(graph_id_map); +} + +CNodePtr AscendControlParser::GetNextRealKernel(std::vector list, size_t start) { + for (size_t i = start; i < list.size() - 1; ++i) { + if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) { + return list[i]; + } + } + return nullptr; } NotNull AscendControlParser::ProcessKernelGraph(NotNull kg, const CNodePtr &last_node, - const CNodePtr &last_label, const VectorRef &args, + const CNodePtr &last_label, NotNull *> memo) { MS_LOG(INFO) << "Start process KernelGraph " << kg->ToString(); - // 0. recursive condition + + // 1. recursive condition if (memo->find(kg) != memo->end()) { MS_LOG(INFO) << "KernelGraph has beed processed: " << kg->ToString(); return NOT_NULL(kg->get_start_label()); } + memo->insert(kg.get()); // 2. args replace placeholder - LinkParentGraph(kg, last_node, last_label, args); + LinkParentGraph(kg, last_node, last_label, memo); + // 3. topological sort - std::vector nodes = GetCNodes(TopoSort(kg->get_return())); + kg->SetExecOrderByDefault(); + std::vector nodes = kg->execution_order(); if (nodes.empty()) { MS_LOG(EXCEPTION) << "KernelGraph " << kg->ToString() << " has no cnodes!"; } // 4. insert first_label auto start_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSetOpName))}); - for (auto node : nodes) { - if (!IsPrimitiveCNode(node, prim::kPrimPartial)) { - InsertControlDependToGraph(kg, NOT_NULL(start_label), NOT_NULL(node)); - break; - } - } - + MS_LOG(INFO) << "Insert start label " << start_label->DebugString() << " to " << kg->ToString(); kg->set_start_label(start_label); // 5. traverse for (size_t i = 0; i < nodes.size(); ++i) { @@ -79,17 +108,19 @@ NotNull AscendControlParser::ProcessKernelGraph(NotNullinput(kCNodeCallArg); if (IsValueNode(arg)) { - RecurseCall(kg, NOT_NULL(cnode), (i + 1 < nodes.size() ? nodes[i + 1] : nullptr), memo); + RecurseCall(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); } else if (!arg->isa()) { MS_LOG(EXCEPTION) << "Unknown type call node " << cnode->DebugString(); } else if (IsPrimitiveCNode(arg->cast(), prim::kPrimSwitch)) { auto arg_cnode = arg->cast(); - cnode->set_inputs(cnode->inputs()); - RecurseSwitch(kg, NOT_NULL(cnode), memo); + MS_EXCEPTION_IF_NULL(arg_cnode); + cnode->set_inputs(arg_cnode->inputs()); + RecurseSwitch(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); } else if (IsPrimitiveCNode(arg->cast(), prim::kPrimSwitchLayer)) { auto arg_cnode = arg->cast(); - cnode->set_inputs(cnode->inputs()); - RecurseSwitchLayer(kg, NOT_NULL(cnode), memo); + MS_EXCEPTION_IF_NULL(arg_cnode); + cnode->set_inputs(arg_cnode->inputs()); + RecurseSwitchLayer(kg, NOT_NULL(cnode), GetNextRealKernel(nodes, i + 1), memo); } } @@ -97,16 +128,6 @@ NotNull AscendControlParser::ProcessKernelGraph(NotNull AscendControlParser::GetCNodes(const std::vector &in) { - std::vector out; - for (auto &node : in) { - if (node->isa()) { - out.push_back(node->cast()); - } - } - return out; -} - void AscendControlParser::InsertDependToGraph(NotNull kg, NotNull attch_node) { std::vector inputs = {NewValueNode(std::make_shared("depend"))}; auto return_node = kg->get_return(); @@ -128,11 +149,7 @@ void AscendControlParser::InsertControlDependToGraph(NotNull kg, } void AscendControlParser::LinkParentGraph(NotNull kg, const CNodePtr &from_graph_call_node, - const CNodePtr &last_label, const VectorRef &args) { - if (from_graph_call_node != nullptr) { - SetSubGraphInput(kg, NOT_NULL(from_graph_call_node), args); - } - + const CNodePtr &last_label, NotNull *> memo) { auto origin_return = kg->get_return(); std::vector origin_return_inputs = origin_return->inputs(); // if entry graph, replace return with make_tuple @@ -146,7 +163,8 @@ void AscendControlParser::LinkParentGraph(NotNull kg, const CNod // else replace return with label_goto auto label_goto = kg->NewCNode({std::make_shared(std::make_shared(kLabelGotoOpName)), last_label}); - InsertDependToGraph(kg, NOT_NULL(label_goto)); + MS_LOG(INFO) << "Insert end goto " << label_goto->DebugString() << " to " << kg->ToString(); + kg->set_end_goto(label_goto); } } @@ -157,13 +175,14 @@ void AscendControlParser::RecurseCall(NotNull kg, NotNullinputs(); std::vector new_inputs = {std::make_shared(std::make_shared(kLabelGotoOpName))}; - auto call_args = GetCallArgs(origin_inputs.begin() + 1, origin_inputs.end()); if (!IsValueNode(origin_inputs[kCNodeCallArg])) { MS_LOG(WARNING) << "Node " << cur_node->DebugString(10) << " index " << kCNodeCallArg << " is not a ValueNode"; return; } // 2 return label auto back_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSetOpName))}); + MS_LOG(INFO) << "Insert back label " << back_label->DebugString() << " to " << kg->ToString() << " call node " + << cur_node->DebugString(); // 3 add depend relationship InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); if (next_node != nullptr && next_node != kg->get_return()) { @@ -173,7 +192,7 @@ void AscendControlParser::RecurseCall(NotNull kg, NotNullset_input(kCNodePrim, new_inputs[kCNodePrim]); // 5 recurse sub graph - CNodePtr sub_label = ProcessKernelGraph(NOT_NULL(call_kg), cur_node, back_label, call_args, memo); + CNodePtr sub_label = ProcessKernelGraph(NOT_NULL(call_kg), cur_node, back_label, memo); new_inputs.push_back(sub_label); new_inputs.insert(new_inputs.end(), origin_inputs.begin(), origin_inputs.end()); cur_node->set_inputs(new_inputs); @@ -182,32 +201,37 @@ void AscendControlParser::RecurseCall(NotNull kg, NotNull kg, NotNull cur_node, - NotNull *> memo) { + const CNodePtr &next_node, NotNull *> memo) { MS_LOG(INFO) << "process switch node " << cur_node->DebugString(); if (cur_node->size() < kCNodeSwitchLength) { MS_LOG(EXCEPTION) << "Inputs of apply node must more than " << kCNodeSwitchLength; } // 1 return label - auto back_label = kg->NewCNode({std::make_shared(prim::kPrimLabelSet)}); - // 2 recurse sub graph + auto back_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSetOpName))}); + MS_LOG(INFO) << "Insert back label " << back_label->DebugString() << " to " << kg->ToString() << " switch node " + << cur_node->DebugString(); + // 2 add depend relationship + InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); + if (next_node != nullptr && next_node != kg->get_return()) { + InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node)); + } + // 3 recurse sub graph auto origin_switch_inputs = cur_node->inputs(); std::vector new_switch_inputs = { std::make_shared(std::make_shared(kLabelSwitchOpName)), origin_switch_inputs[kCNodeSwitchCond]}; for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) { - // 2.1 branch kernel graph and args + // 3.1 branch kernel graph and args CNodePtr partial; KernelGraphPtr branch_fg; - VectorRef call_args; - std::tie(partial, branch_fg, call_args) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); - // 2.2 add depend relationship - InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); - // 2.3 recurse sub graph - CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, call_args, memo); + std::tie(partial, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); + // 3.2 recurse sub graph + CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); new_switch_inputs.push_back(branch_label); } std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]); + new_switch_inputs.insert(new_switch_inputs.end(), origin_switch_inputs.begin(), origin_switch_inputs.end()); cur_node->set_inputs(new_switch_inputs); cur_node->set_abstract(nullptr); @@ -215,7 +239,7 @@ void AscendControlParser::RecurseSwitch(NotNull kg, NotNull kg, NotNull cur_node, - NotNull *> memo) { + const CNodePtr &next_node, NotNull *> memo) { MS_LOG(INFO) << "process switch node " << cur_node->DebugString(); if (cur_node->size() < kCNodeSwitchLayerLength) { @@ -229,21 +253,24 @@ void AscendControlParser::RecurseSwitchLayer(NotNull kg, NotNull } auto branch_partial = utils::cast(branch_tuple)->inputs(); // 1 return label - auto back_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSwitchOpName))}); - // 2 recurse sub graph + auto back_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSetOpName))}); + // 2 add depend relationship + InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); + if (next_node != nullptr && next_node != kg->get_return()) { + InsertControlDependToGraph(kg, NOT_NULL(back_label), NOT_NULL(next_node)); + } + // 3 recurse sub graph auto origin_switch_inputs = cur_node->inputs(); - std::vector new_switch_inputs = {std::make_shared(prim::kPrimLabelSwitch), - origin_switch_inputs[kCNodeSwitchCond]}; + std::vector new_switch_inputs = { + std::make_shared(std::make_shared(kLabelSwitchOpName)), + origin_switch_inputs[kCNodeSwitchCond]}; for (size_t i = 0; i < branch_partial.size(); ++i) { - // 2.1 branch kernel graph and args + // 3.1 branch kernel graph and args CNodePtr partial; KernelGraphPtr branch_fg; - VectorRef call_args; - std::tie(partial, branch_fg, call_args) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); - // 2.2 add depend relationship - InsertControlDependToGraph(kg, cur_node, NOT_NULL(back_label)); - // 2.3 recurse sub graph - CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, call_args, memo); + std::tie(partial, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); + // 3.2 recurse sub graph + CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); new_switch_inputs.push_back(branch_label); } new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end()); @@ -252,7 +279,7 @@ void AscendControlParser::RecurseSwitchLayer(NotNull kg, NotNull MS_LOG(INFO) << "success process switch layer " << cur_node->DebugString(); } -std::tuple AscendControlParser::ParsePartial(NotNull node) { +std::tuple AscendControlParser::ParsePartial(NotNull node) { if (!node.get()->isa()) { MS_LOG(EXCEPTION) << "Switch branches must be partial, node: " << node->DebugString(); } @@ -263,9 +290,8 @@ std::tuple AscendControlParser::ParsePartia } auto partial_inputs = partial_cnode->inputs(); auto branch_kg = GetValueNode(partial_inputs[kCNodePartialFunc]); - auto call_args = GetCallArgs(partial_inputs.begin() + kCNodePartialFunc + 1, partial_inputs.end()); - return {partial_cnode, branch_kg, call_args}; + return {partial_cnode, branch_kg}; } void AscendControlParser::InsertAssignToGraph(NotNull kg, NotNull from, @@ -289,31 +315,199 @@ void AscendControlParser::InsertAssignToGraph(NotNull kg, NotNul InsertDependToGraph(kg, NOT_NULL(assign_node)); } -size_t AscendControlParser::SetChildGraphInput(NotNull kg, NotNull node, - size_t input_index) { - auto output_num = AnfAlgo::GetOutputTensorNum(node); - if (output_num > 1 && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { - return input_index + output_num; +NotNull AscendControlParser::GetRealInput(NotNull from_graph, + NotNull to_graph, NotNull param) { + std::set args_list = to_graph->GetRealInput(param); + for (auto arg : args_list) { + if (arg->func_graph() == from_graph.get()) { + return NOT_NULL(arg); + } } - - auto &graph_inputs = kg->inputs(); - if (input_index >= graph_inputs.size()) { - MS_LOG(EXCEPTION) << "input_index " << input_index << " out of range size " << graph_inputs.size(); - } - auto backend_parameter = graph_inputs[input_index]; - if (node.get()->isa()) { - MS_EXCEPTION_IF_NULL(backend_parameter); - MS_LOG(INFO) << "Reuse node [" << node->DebugString() << "], old node[" << backend_parameter->DebugString() - << "] will be replaced."; - kg->ReplaceNode(backend_parameter, node); - return input_index; - } - InsertAssignToGraph(kg, node, NOT_NULL(backend_parameter)); - return input_index + 1; + MS_LOG(EXCEPTION) << to_graph->ToString() << " input " << param->DebugString() << " not from " + << from_graph->ToString(); } -void AscendControlParser::SetSubGraphInput(NotNull kg, NotNull from_graph_call_node, - const VectorRef &args) {} +void AscendControlParser::LinkArgsToParam(NotNull to_graph, NotNull target_graph, + NotNull arg, NotNull param) { + if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple) && IsPrimitiveCNode(param, prim::kPrimMakeTuple)) { + MS_LOG(INFO) << "Arg " << arg->DebugString() << " Param " << param->DebugString() << " is a tuple"; + CNodePtr cnode_arg = arg.get()->cast(); + CNodePtr cnode_param = param.get()->cast(); + MS_EXCEPTION_IF_NULL(cnode_arg); + MS_EXCEPTION_IF_NULL(cnode_param); + if (cnode_arg->size() != cnode_param->size()) { + MS_LOG(EXCEPTION) << "Arg " << arg->DebugString() << " size " << cnode_arg->size() << " but Param " + << param->DebugString() << " size " << cnode_param->size(); + } + + for (size_t i = 1; i < cnode_param->size(); ++i) { + LinkArgsToParam(to_graph, target_graph, NOT_NULL(cnode_arg->input(i)), NOT_NULL(cnode_param->input(i))); + } + } else if (arg->isa()) { + InsertAssignToGraph(target_graph, arg, param); + } else { + MS_LOG(EXCEPTION) << "Arg " << arg->DebugString() << " Param " << param->DebugString() << " unknown type."; + } +} + +void AscendControlParser::ExecutorValidate(NotNull root_graph) { + std::set memo; + (void)RecurseGraph(nullptr, nullptr, root_graph, NOT_NULL(&memo)); +} + +std::vector AscendControlParser::RecurseGraph(const CNodePtr &cur_label_goto, const CNodePtr &end_label_goto, + NotNull graph, + NotNull *> memo) { + MS_LOG(INFO) << "graph:" << graph->graph_id() << " start"; + auto print_vector = [&](std::vector vec) -> void { + MS_LOG(INFO) << "graph:" << graph->graph_id() << "execution order"; + for (size_t i = 0; i < vec.size(); i++) { + MS_LOG(INFO) << "[" << i << "][" << vec[i]->DebugString() << "]"; + } + }; + if (memo->find(graph) != memo->end()) { + return {}; + } + memo->insert(graph.get()); + + graph->SetExecOrderByDefault(); + + std::vector cnodes = graph->execution_order(); + std::map label_map; + std::map> label_switch_map; + std::tie(label_map, label_switch_map) = GetLabelNode(cnodes); + std::vector execution_order; + + for (auto &node : cnodes) { + execution_order.push_back(node); + if (node == graph->get_end_goto()) { + continue; + } + + auto label_iter = + std::find_if(label_map.begin(), label_map.end(), + [node](const std::map::value_type iter) { return iter.second == node; }); + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { + if (!CheckLabelIndex(label_iter->first, 0, label_iter->second, graph)) { + MS_LOG(EXCEPTION) << "Check label index fail"; + } + auto child_graph = graph->child_graph_order()[label_iter->first]; + if (child_graph == graph->parent_graph()) { + continue; + } + std::map child_label_map; + std::tie(child_label_map, std::ignore) = GetLabelNode(child_graph->execution_order()); + auto child_execution_order = + RecurseGraph(child_label_map.begin()->second, child_label_map.rbegin()->second, NOT_NULL(child_graph), memo); + execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); + } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { + std::vector label_list = label_switch_map.find(node)->second; + std::reverse(label_list.begin(), label_list.end()); + for (size_t i = 0; i < label_list.size(); ++i) { + if (!CheckLabelIndex(label_iter->first + i, label_list[i], label_iter->second, graph)) { + MS_LOG(EXCEPTION) << "Check label index fail"; + } + auto child_graph = graph->child_graph_order()[label_iter->first + i]; + if (child_graph == graph->parent_graph()) { + continue; + } + std::map child_label_map; + std::tie(child_label_map, std::ignore) = GetLabelNode(child_graph->execution_order()); + auto child_execution_order = + RecurseGraph(child_label_map.begin()->second, child_label_map.rbegin()->second, NOT_NULL(child_graph), memo); + execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); + } + } + } + graph->set_execution_order(execution_order); + print_vector(graph->execution_order()); + return execution_order; +} + +bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label, + NotNull graph) { + // check index and child order size + if (graph->child_graph_order().size() <= static_cast(order_index)) { + MS_LOG(EXCEPTION) << "Child graph order is wrong, graph " << graph->ToString() << " child graph size " + << graph->child_graph_order().size() << " goto index " << order_index; + } + + if (AnfAlgo::CheckPrimitiveType(cur_label, prim::kPrimLabelGoto)) { + // check label_goto and start_label in child graph + if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, cur_label)) { + MS_LOG(EXCEPTION) << "LabelSetKernel has no attr label_index"; + } + auto primitive = AnfAlgo::GetCNodePrimitive(cur_label); + MS_EXCEPTION_IF_NULL(primitive); + uint32_t label_goto_index = GetValue(primitive->GetAttr(kAttrLabelIndex)); + label_index = label_goto_index; + } + // get start_label_set_index of child graph + auto child_graph = graph->child_graph_order()[order_index]; + MS_EXCEPTION_IF_NULL(child_graph); + auto start_label_set = child_graph->get_start_label(); + if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, start_label_set)) { + MS_LOG(EXCEPTION) << "LabelSetKernel has no attr label_index"; + } + auto start_primitive = AnfAlgo::GetCNodePrimitive(start_label_set); + MS_EXCEPTION_IF_NULL(start_primitive); + uint32_t start_label_set_index = GetValue(start_primitive->GetAttr(kAttrLabelIndex)); + if (label_index != start_label_set_index) { + MS_LOG(WARNING) << cur_label->DebugString() << " index " << label_index << " but " << start_label_set->DebugString() + << " index " << start_label_set_index << " current child graph order : " << order_index; + return false; + } + return true; +} + +std::tuple, std::map>> AscendControlParser::GetLabelNode( + const std::vector &nodes) { + std::map label_map; + std::map> label_switch_map; + // record child graph + uint32_t index = 0; + for (auto &node : nodes) { + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { + label_map[index] = node; + ++index; + } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { + if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, node)) { + MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list"; + } + auto primitive = AnfAlgo::GetCNodePrimitive(node); + MS_EXCEPTION_IF_NULL(primitive); + std::vector label_list = GetValue>(primitive->GetAttr(kAttrLabelSwitchList)); + label_switch_map.insert({node, label_list}); + for (size_t i = 0; i < label_list.size(); ++i) { + label_map[index] = node; + ++index; + } + } + } + return {label_map, label_switch_map}; +} + +void AscendControlParser::UpdateChildGraphOrder(NotNull kg) { + MS_LOG(INFO) << "graph id:" << kg->graph_id(); + kg->SetExecOrderByDefault(); + auto call_nodes = kg->FindNodeByPrimitive(std::make_shared(prim::kPrimCall->name())); + std::vector child_graph_order; + for (auto &call_node : call_nodes) { + MS_EXCEPTION_IF_NULL(call_node); + auto call_child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast()); + for (const auto &child_graph : call_child_graphs) { + MS_EXCEPTION_IF_NULL(child_graph); + if (child_graph != kg->parent_graph()) { + child_graph->set_parent_graph(kg.get()); + } + child_graph_order.push_back(child_graph); + } + } + for (size_t i = 0; i < child_graph_order.size(); i++) { + MS_LOG(INFO) << "child graph[" << i << "][id:" << child_graph_order[i]->graph_id() << "]"; + } + kg->set_child_graph_order(child_graph_order); +} } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/session/ascend_control_parser.h b/mindspore/ccsrc/session/ascend_control_parser.h index ca215ef0c21..0f08d39c828 100644 --- a/mindspore/ccsrc/session/ascend_control_parser.h +++ b/mindspore/ccsrc/session/ascend_control_parser.h @@ -17,6 +17,7 @@ #define MINDSPORE_CCSRC_SESSION_ASCEND_CONTROL_PARSER_H #include +#include #include #include #include "session/kernel_graph.h" @@ -28,31 +29,44 @@ namespace session { class AscendControlParser { public: + static void ChildGraphDataAssign(const std::map &graph_id_map); static void LinkGraph(NotNull kg); static void InsertDependToGraph(NotNull kg, NotNull attch_node); static void InsertControlDependToGraph(NotNull kg, NotNull first_node, NotNull second_node); + static void ExecutorValidate(NotNull root_graph); + static void UpdateChildGraphOrder(NotNull kg); private: static NotNull ProcessKernelGraph(NotNull kg, const CNodePtr &last_node, - const CNodePtr &last_label, const VectorRef &args, - NotNull *> memo); + const CNodePtr &last_label, NotNull *> memo); static void RecurseCall(NotNull kg, NotNull cur_node, const CNodePtr &next_node, NotNull *> memo); - static void RecurseSwitch(NotNull kg, NotNull cur_node, + static void RecurseSwitch(NotNull kg, NotNull cur_node, const CNodePtr &next_node, NotNull *> memo); - static void RecurseSwitchLayer(NotNull kg, NotNull cur_node, + static void RecurseSwitchLayer(NotNull kg, NotNull cur_node, const CNodePtr &next_node, NotNull *> memo); - static std::vector GetCNodes(const std::vector &in); static void LinkParentGraph(NotNull kg, const CNodePtr &from_graph_call_node, - const CNodePtr &last_label, const VectorRef &args); - static void SetSubGraphInput(NotNull kg, NotNull from_graph_call_node, - const VectorRef &args); - static std::tuple ParsePartial(NotNull node); + const CNodePtr &last_label, NotNull *> memo); + static std::tuple ParsePartial(NotNull node); + + static void LinkArgsToParam(NotNull to_graph, NotNull target_graph, + NotNull arg, NotNull param); + static NotNull GetRealInput(NotNull from_graph, NotNull to_graph, + NotNull param); static void InsertAssignToGraph(NotNull kg, NotNull from, NotNull to); - static size_t SetChildGraphInput(NotNull kg, NotNull node, size_t input_index); + + static CNodePtr GetNextRealKernel(std::vector list, size_t start); + + // root graph order + static std::tuple, std::map>> GetLabelNode( + const std::vector &nodes); + static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode, + NotNull graph); + static std::vector RecurseGraph(const CNodePtr &cur_label_goto, const CNodePtr &end_label_goto, + NotNull graph, NotNull *> memo); static constexpr size_t kCNodePrim = 0; static constexpr size_t kCNodeCallArg = 1; diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc index 9fe9fc9f4bc..b99c99443d2 100644 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -177,10 +177,6 @@ std::vector> GetChildList(const KernelGraph &cur_graph, co for (size_t i = 0; i < cnodes.size(); i++) { if (AnfAlgo::CheckPrimitiveType(cnodes[i], prim::kPrimCall) && !AnfAlgo::IsSwitchCall(cnodes[i])) { auto call_kernel_graph = AnfAlgo::GetCallNodeKernelGraph(cnodes[i]); - // if graph is the true branch of while,no need split graph - if (call_kernel_graph.size() == 1 && call_kernel_graph[0] == cur_graph.parent_graph()) { - continue; - } auto prev_call_list = std::vector(cnodes.begin() + after_call_index, cnodes.begin() + i); auto call_list = std::vector(1, cnodes[i]); after_call_index = i + 1; @@ -195,10 +191,10 @@ std::vector> GetChildList(const KernelGraph &cur_graph, co // if a call has kernel input, it's a child graph split from ME, so these kernel input should be set into real input of // graph.For example, call input = (prim,graph,kernel1,kernel2),then real_input = [kernel1,kernel2] -void UpdateRealInput(KernelGraph *graph) { +static void UpdateRealInput(KernelGraph *graph) { auto call_nodes = graph->FindNodeByPrimitive(prim::kPrimCall); - auto bind_call_partial_with_parameter = [&](const std::vector ¶meters, - const std::vector &args, KernelGraph *child_graph) -> void { + auto bind_call_arg_with_parameter = [&](const std::vector ¶meters, + const std::vector &args, KernelGraph *child_graph) -> void { MS_EXCEPTION_IF_NULL(child_graph); MS_LOG(INFO) << "start bind parameter of child graph:" << child_graph->graph_id(); if (args.empty()) { @@ -208,8 +204,21 @@ void UpdateRealInput(KernelGraph *graph) { MS_LOG(EXCEPTION) << "graph:" << child_graph->graph_id() << " parameters size:" << parameters.size() << " and args size:" << args.size() << " not equal!"; } + child_graph->SetExecOrderByDefault(); for (size_t i = 0; i < parameters.size(); i++) { - MS_LOG(INFO) << "bind paramreter:" << parameters[i]->DebugString() << " ,arg:" << args[i]->DebugString(); + if (args[i] == parameters[i]) { + child_graph->SetRealInput(parameters[i], args[i]); + MS_LOG(INFO) << "Parameter and arg are same"; + continue; + } + // if arg is a parameter ,then reuse this parameter + if (args[i]->isa()) { + MS_LOG(INFO) << "Parameter:" << parameters[i]->DebugString() << " of graph:" << child_graph->graph_id() + << " reuse parameter:" << args[i]->DebugString() + << " of graph:" << AnfAlgo::GetGraphId(args[i].get()); + child_graph->ReplaceNode(parameters[i], args[i]); + continue; + } child_graph->SetRealInput(parameters[i], args[i]); } }; @@ -218,9 +227,10 @@ void UpdateRealInput(KernelGraph *graph) { auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node); if (child_graphs.size() == 1) { MS_EXCEPTION_IF_NULL(child_graphs[0]); - bind_call_partial_with_parameter( - child_graphs[0]->inputs(), std::vector(call_node->inputs().begin() + 2, call_node->inputs().end()), - child_graphs[0].get()); + std::vector real_args = + std::vector(call_node->inputs().begin() + 2, call_node->inputs().end()); + std::vector child_inputs = child_graphs[0]->inputs(); + bind_call_arg_with_parameter(child_inputs, real_args, child_graphs[0].get()); call_node->set_inputs(std::vector(call_node->inputs().begin(), call_node->inputs().begin() + 2)); } else if (child_graphs.size() == 2) { auto get_partial_args = [&](size_t input_index) -> std::vector { @@ -237,8 +247,8 @@ void UpdateRealInput(KernelGraph *graph) { std::vector(partial_cnode->inputs().begin(), partial_cnode->inputs().begin() + 2)); return ret; }; - bind_call_partial_with_parameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get()); - bind_call_partial_with_parameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get()); + bind_call_arg_with_parameter(child_graphs[0]->inputs(), get_partial_args(2), child_graphs[0].get()); + bind_call_arg_with_parameter(child_graphs[1]->inputs(), get_partial_args(3), child_graphs[1].get()); } } } @@ -248,6 +258,11 @@ void RecurseToUpdateCallRealInput(KernelGraph *graph) { MS_LOG(INFO) << "start graph id:" << graph->graph_id(); graph->UpdateCallRealInput(); for (auto &child_graph : graph->child_graph_order()) { + if (child_graph == graph->parent_graph()) { + MS_LOG(INFO) << "Child graph:" << child_graph->graph_id() + << ",parent graph:" << graph->parent_graph()->graph_id(); + continue; + } RecurseToUpdateCallRealInput(child_graph.get()); } } @@ -265,31 +280,31 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL GraphId AscendSession::CompileGraph(NotNull func_graph) { MS_LOG(INFO) << "start"; auto graph = ConstructKernelGraph(func_graph); + MS_LOG(INFO) << "graph input size:" << graph->inputs().size(); // split switch SplitGraphs(graph); + MS_LOG(INFO) << "graph input size:" << graph->inputs().size(); // insert goto labels and label_sets LinkChildGraphs(NOT_NULL(graph)); + MS_LOG(INFO) << "graph input size:" << graph->inputs().size(); // resource initialize InitRuntimeResource(); // assign label AssignLabel(NOT_NULL(graph)); - if (!graph->executable()) { - return graph->graph_id(); - } - for (auto iter : graphs_) { - if (iter.second == graph) { - MS_LOG(INFO) << "Entry graph " << graph->ToString() << " graph id " << graph->graph_id(); - final_graph_id_ = graph->graph_id(); - } - MS_LOG(INFO) << "CompileChildGraph " << iter.second->ToString(); - CompileChildGraph(iter.second); - } + MS_LOG(INFO) << "graph input size:" << graph->inputs().size(); + // recurse compile child graph + RecurseCompileGraph(graph); + MS_LOG(INFO) << "graph input size:" << graph->inputs().size(); + // root graph valiate,include genearte execute order and so on + RootGraphExecutorValidate(NOT_NULL(graph)); + MS_LOG(INFO) << "graph input size:" << graph->inputs().size(); // adjust kernel AdjustKernel(graph); - // root graph valiate,include genearte execute order and so on - RootGraphExecutorValidate(graph.get()); + MS_LOG(INFO) << "graph input size:" << graph->inputs().size(); // assign stream AssignStream(graph); + // build kernel + BuildKernel(graph); // alloc mem MemoryAlloc(graph.get()); // task generate @@ -365,6 +380,7 @@ void AscendSession::BuildGraph(GraphId graph_id) { void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) { MS_EXCEPTION_IF_NULL(child_graph); + MS_LOG(INFO) << "CompileChildGraph " << child_graph->ToString(); opt::AscendBackendIRFusionOptimization(child_graph); // select kernel build info SelectKernel(*child_graph); @@ -376,12 +392,14 @@ void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) { auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); MS_EXCEPTION_IF_NULL(runtime_instance); runtime_instance->AssignStaticMemoryInput(child_graph.get()); + runtime_instance->AssignStaticMemoryValueNode(child_graph.get()); } void AscendSession::RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *const outputs) { MS_LOG(INFO) << "start"; auto kernel_graph = GetGraph(graph_id); + DumpIR("./run_graph.ir", kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph); // if none of child graph and no anf output exists if (!kernel_graph->executable()) { @@ -1378,10 +1396,10 @@ void AscendSession::SyncInitialTenosrToDevice() { } } -KernelGraphPtr AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, - const std::vector &list) { +std::vector AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, + const std::vector &list) { MS_EXCEPTION_IF_NULL(new_kernel_graph); - MS_LOG(INFO) << "start split kernel graph:" << new_kernel_graph->graph_id(); + MS_LOG(INFO) << "start contruct splited kernel graph:" << new_kernel_graph->graph_id(); // count the output of every anf node std::set has_output_nodes; for (auto &anf_node : list) { @@ -1390,21 +1408,23 @@ KernelGraphPtr AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_ke } } MS_LOG(INFO) << "Construct input of kernel graph:" << new_kernel_graph->graph_id(); + std::vector call_node_inputs; + auto graph_inputs = new_kernel_graph->MutableInputs(); + MS_EXCEPTION_IF_NULL(graph_inputs); // create new parameter from cnode for (auto &anf_node : list) { auto cnode = anf_node->cast(); for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { auto input = cnode->inputs()[input_idx]; MS_EXCEPTION_IF_NULL(input); - if (!input->isa()) { + if (input->isa()) { + graph_inputs->push_back(input); cnode->set_input(input_idx, input); - continue; - } - if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) { + } else if (AnfAlgo::GetGraphId(input.get()) != new_kernel_graph->graph_id()) { auto new_parameter = CreateNewParameterFromCNode(input, true, new_kernel_graph.get()); cnode->set_input(input_idx, new_parameter); - new_kernel_graph->SetRealInput(new_parameter, input); } + call_node_inputs.push_back(input); } } MS_LOG(INFO) << "Construct output of kernel graph:" << new_kernel_graph->graph_id(); @@ -1424,7 +1444,7 @@ KernelGraphPtr AscendSession::ConstructSplitedGraph(const KernelGraphPtr &new_ke new_kernel_graph->set_output(new_kernel_graph->NewCNode(make_tuple_inputs)); } MS_LOG(INFO) << "end"; - return new_kernel_graph; + return call_node_inputs; } void AscendSession::SplitGraphs(const KernelGraphPtr &root_graph) { @@ -1438,7 +1458,7 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); auto apply_list = GetCNodes(TopoSort(graph->get_return())); // update the root graph child graph order - graph->UpdateChildGraphOrder(); + AscendControlParser::UpdateChildGraphOrder(NOT_NULL(graph)); // get child list from current graph std::vector> child_graph_lists = GetChildList(*graph, apply_list); auto bind_new_call_to_new_graph = [&](std::vector child_graph_list) -> AnfNodePtr { @@ -1457,7 +1477,8 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) { for (auto &child_graph_node : child_graph_list) { AnfAlgo::SetGraphId(child_graph->graph_id(), child_graph_node.get()); } - ConstructSplitedGraph(child_graph, child_graph_list); + auto call_node_args = ConstructSplitedGraph(child_graph, child_graph_list); + std::copy(call_node_args.begin(), call_node_args.end(), std::back_inserter(new_call_input)); auto new_call = graph->NewCNode(new_call_input); AnfAlgo::SetNodeAttr("graph id", MakeValue(graph->graph_id()), new_call); return new_call; @@ -1466,26 +1487,59 @@ void AscendSession::SplitGraph(const KernelGraphPtr &graph) { std::list depend_input = {}; for (size_t call_index = 0; call_index < child_graph_lists.size(); call_index++) { auto call_node = bind_new_call_to_new_graph(child_graph_lists[call_index]); + MS_EXCEPTION_IF_NULL(call_node); + // if call node is the last call of true graph,no need create child graph after that + auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(call_node->cast()); depend_input.push_front(call_node); + if (child_graphs.size() == 1 && child_graphs[0] == graph->parent_graph()) { + break; + } } depend_input.push_front(graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimDepend->name())))); auto depend = graph->NewCNode(std::vector(depend_input.begin(), depend_input.end())); auto new_return_primitive = graph->NewValueNode(NewValueNode(std::make_shared(prim::kPrimReturn->name()))); graph->set_return(graph->NewCNode({new_return_primitive, depend})); + AnfNodePtr pre_call_node = nullptr; + AnfNodePtr cur_call_node = nullptr; + auto iter = depend_input.begin(); + for (++iter; iter != depend_input.end(); ++iter) { + pre_call_node = cur_call_node; + cur_call_node = *iter; + if (pre_call_node != nullptr && cur_call_node != nullptr) { + AscendControlParser::InsertControlDependToGraph(NOT_NULL(graph), NOT_NULL(cur_call_node), + NOT_NULL(pre_call_node)); + } + } } - graph->UpdateChildGraphOrder(); + AscendControlParser::UpdateChildGraphOrder(NOT_NULL(graph)); UpdateRealInput(graph.get()); auto graph_name = std::string("./kernel-graph-").append(std::to_string(graph->graph_id())); DumpIR(graph_name, graph); MS_LOG(INFO) << "split graph[" << graph->graph_id() << "] end"; // recurse to split child graph for (auto &child_graph : graph->child_graph_order()) { - SplitGraph(child_graph); + if (child_graph != graph->parent_graph()) { + SplitGraph(child_graph); + } } } void AscendSession::LinkChildGraphs(NotNull graph) { AscendControlParser::LinkGraph(graph); } +void AscendSession::RootGraphExecutorValidate(NotNull graph) { + AscendControlParser::ExecutorValidate(graph); +} + +void AscendSession::RecurseCompileGraph(const KernelGraphPtr &graph) { + CompileChildGraph(graph); + for (auto child_graph : graph->child_graph_order()) { + if (child_graph == graph->parent_graph()) { + continue; + } + RecurseCompileGraph(child_graph); + } +} + } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/session/ascend_session.h b/mindspore/ccsrc/session/ascend_session.h index d8b60cf3b3c..26faed15adf 100755 --- a/mindspore/ccsrc/session/ascend_session.h +++ b/mindspore/ccsrc/session/ascend_session.h @@ -104,10 +104,10 @@ class AscendSession : public SessionBasic { void SelectKernelGraphKernel(const KernelGraph &graph) {} void ConvertPredictModel(const KernelGraphPtr graph) {} void HardwareOptimizeGraphs(const KernelGraphPtr graph) {} - void RootGraphExecutorValidate(KernelGraph *graph) {} - void RecurseUpdateAllChildGraohOrder(KernelGraph *root_graph); - KernelGraphPtr ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, const std::vector &list); - void ChildGraphCommunicationDecrease(std::vector> *anf_node_lists); + void RootGraphExecutorValidate(NotNull graph); + std::vector ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph, + const std::vector &list); + void RecurseCompileGraph(const KernelGraphPtr &graph); // merge execution order list of child graphs void MergeGraphExecOrder(); diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc index 7068dace57b..1db932cd306 100644 --- a/mindspore/ccsrc/session/kernel_graph.cc +++ b/mindspore/ccsrc/session/kernel_graph.cc @@ -165,6 +165,21 @@ void KernelGraph::SetExecOrderByDefault() { } } CheckLoop(); + // resort start label / end goto + std::vector re_order; + if (start_label_ != nullptr) { + re_order.push_back(start_label_); + } + for (auto &node : execution_order_) { + if (node == start_label_ || node == end_goto_) { + continue; + } + re_order.push_back(node); + } + if (end_goto_ != nullptr) { + re_order.push_back(end_goto_); + } + execution_order_ = re_order; } void KernelGraph::CheckLoop() { @@ -360,7 +375,8 @@ void KernelGraph::FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNode void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf) { MS_EXCEPTION_IF_NULL(old_backend_anf); MS_EXCEPTION_IF_NULL(new_backend_anf); - if (old_backend_anf.get() == new_backend_anf.get()) { + if (old_backend_anf == new_backend_anf) { + MS_LOG(INFO) << "old:" << old_backend_anf->DebugString() << ",new:" << new_backend_anf->DebugString(); MS_LOG(EXCEPTION) << "old can't be same with new"; } if (backend_front_anf_map_.find(old_backend_anf) == backend_front_anf_map_.end()) { @@ -569,32 +585,52 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf MS_EXCEPTION_IF_NULL(new_anf_node); MS_EXCEPTION_IF_NULL(inputs_); auto it = node_output_edges_.find(old_anf_node); - if (it == node_output_edges_.end()) { - MS_LOG(EXCEPTION) << "Can't find anf node in node_output_edges map"; - } - auto &outputs = it->second; - for (auto &output_node : outputs) { - auto output_cnode = output_node.first->cast(); - MS_EXCEPTION_IF_NULL(output_cnode); - auto &output_node_inputs = output_cnode->inputs(); - for (size_t i = 1; i < output_node_inputs.size(); i++) { - if (output_node_inputs[i] == old_anf_node) { - output_cnode->set_input(i, new_anf_node); - } - } - // update graph inputs - for (size_t i = 0; i < inputs_->size(); i++) { - if ((*inputs_)[i] == old_anf_node) { - (*inputs_)[i] = new_anf_node; - break; + if (it != node_output_edges_.end()) { + const auto &outputs = it->second; + for (auto &output_node : outputs) { + MS_EXCEPTION_IF_NULL(output_node.first); + auto output_cnode = output_node.first->cast(); + MS_EXCEPTION_IF_NULL(output_cnode); + const auto &output_node_inputs = output_cnode->inputs(); + for (size_t i = 1; i < output_node_inputs.size(); i++) { + if (output_node_inputs[i] == old_anf_node) { + output_cnode->set_input(i, new_anf_node); + } + } + // update graph inputs + for (size_t i = 0; i < inputs_->size(); i++) { + if ((*inputs_)[i] == old_anf_node) { + MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_anf_node->DebugString() + << ",new graph input:" << new_anf_node->DebugString(); + (*inputs_)[i] = new_anf_node; + break; + } + } + MS_LOG(INFO) << "Inputs of graph id:" << graph_id(); + for (size_t i = 0; i < inputs().size(); i++) { + MS_LOG(INFO) << "[" << i << "]:" << inputs()[i]->DebugString(); } } + // update front to backend map + FrontBackendlMapUpdate(old_anf_node, new_anf_node); + // update output depend relations + node_output_edges_[new_anf_node] = it->second; + (void)node_output_edges_.erase(old_anf_node); + } + // update graph inputs in child graph + auto it_real_inputs = real_inputs_.find(old_anf_node); + if (it_real_inputs != real_inputs_.end()) { + // insert new parameter to map + auto iter = real_inputs_.find(new_anf_node); + if (iter != real_inputs_.end()) { + MS_LOG(WARNING) << new_anf_node->DebugString() << " already exist in real inputs, will be rewrited."; + iter->second = it_real_inputs->second; + } else { + real_inputs_[new_anf_node] = it_real_inputs->second; + } + // erase old parameter in map + real_inputs_.erase(old_anf_node); } - // update front to backend map - FrontBackendlMapUpdate(old_anf_node, new_anf_node); - // update output depend relations - node_output_edges_[new_anf_node] = it->second; - (void)node_output_edges_.erase(old_anf_node); } void KernelGraph::UpdateExecuteKernelStreamLabel() { @@ -603,29 +639,6 @@ void KernelGraph::UpdateExecuteKernelStreamLabel() { } } -void KernelGraph::UpdateChildGraphOrder() { - MS_LOG(INFO) << "graph id:" << graph_id_; - auto call_nodes = FindNodeByPrimitive(std::make_shared(prim::kPrimCall->name())); - for (auto &old_child_graph : child_graph_order_) { - old_child_graph->set_parent_graph(nullptr); - } - child_graph_order_.clear(); - for (auto &call_node : call_nodes) { - MS_EXCEPTION_IF_NULL(call_node); - auto call_child_graphs = AnfAlgo ::GetCallNodeKernelGraph(call_node->cast()); - for (const auto &child_graph : call_child_graphs) { - MS_EXCEPTION_IF_NULL(child_graph); - if (child_graph != parent_graph()) { - child_graph->set_parent_graph(shared_from_this()->cast>()); - child_graph_order_.push_back(child_graph); - } - } - } - for (size_t i = 0; i < child_graph_order_.size(); i++) { - MS_LOG(INFO) << "child graph[" << i << "][id:" << child_graph_order_[i]->graph_id() << "]"; - } -} - std::vector> KernelGraph::GetLeafGraphOrder() { std::vector> leaf_graph_order; if (IsLeafGraph()) { @@ -643,9 +656,8 @@ std::vector> KernelGraph::GetLeafGraphOrder() { bool KernelGraph::IsLeafGraph() const { return child_graph_order_.empty(); } std::vector KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primitive) const { - auto anf_list = TopoSort(get_return()); std::vector result; - for (const auto &anf : anf_list) { + for (const auto &anf : execution_order_) { if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) { result.push_back(anf->cast()); } @@ -653,14 +665,6 @@ std::vector KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primi return result; } -std::set KernelGraph::GetRealInput(const AnfNodePtr ¶meter) { - MS_EXCEPTION_IF_NULL(parameter); - if (real_inputs_.find(parameter) == real_inputs_.end()) { - return {}; - } - return real_inputs_[parameter]; -} - void KernelGraph::SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg) { MS_EXCEPTION_IF_NULL(parameter); MS_EXCEPTION_IF_NULL(arg); @@ -674,37 +678,41 @@ void KernelGraph::SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &ar (void)args.insert(arg); } +std::set KernelGraph::GetRealInput(const AnfNodePtr ¶meter) { + MS_EXCEPTION_IF_NULL(parameter); + auto iter = real_inputs_.find(parameter); + if (iter != real_inputs_.end()) { + return iter->second; + } + MS_LOG(EXCEPTION) << parameter->DebugString() << " not found."; +} + void KernelGraph::UpdateCallRealInput() { MS_LOG(INFO) << "Update graph id: " << graph_id_; for (auto &it : real_inputs_) { auto ¶meter = it.first; MS_EXCEPTION_IF_NULL(parameter); auto &real_inputs = it.second; - std::set new_real_inputs; + std::vector new_real_inputs; std::set erase_real_inputs; for (auto &real_input : real_inputs) { // if real input is a call node ,find the child graph output act as the new real input auto item_with_index = AnfAlgo::VisitKernelWithReturnType(real_input, 0); MS_EXCEPTION_IF_NULL(item_with_index.first); if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimCall)) { - MS_LOG(INFO) << "paramter: " << parameter->DebugString() - << " erase real input:" << item_with_index.first->DebugString(); (void)erase_real_inputs.insert(item_with_index.first); - auto call_node_outputs = GetCallRealOutputs(item_with_index.first); - for (auto &call_node_output : call_node_outputs) { - MS_EXCEPTION_IF_NULL(call_node_output); - MS_LOG(INFO) << "paramter: " << parameter->DebugString() - << " insert real input:" << call_node_output->DebugString(); - (void)new_real_inputs.insert(call_node_output); - } + new_real_inputs = GetCallRealOutputs(item_with_index.first); continue; } - for (auto &erase_node : erase_real_inputs) { - (void)real_inputs.erase(erase_node); - } - for (auto &new_real_input : new_real_inputs) { - (void)real_inputs.insert(new_real_input); - } + } + for (auto &erase_node : erase_real_inputs) { + MS_LOG(INFO) << "paramter: " << parameter->DebugString() << " erase real input:" << erase_node->DebugString(); + (void)real_inputs.erase(erase_node); + } + for (auto &new_real_input : new_real_inputs) { + MS_LOG(INFO) << "paramter: " << parameter->DebugString() + << " insert real input:" << new_real_input->DebugString(); + (void)real_inputs.insert(new_real_input); } } } diff --git a/mindspore/ccsrc/session/kernel_graph.h b/mindspore/ccsrc/session/kernel_graph.h index 53e15914e89..524c6b4c684 100644 --- a/mindspore/ccsrc/session/kernel_graph.h +++ b/mindspore/ccsrc/session/kernel_graph.h @@ -103,10 +103,9 @@ class KernelGraph : public FuncGraph { void UpdateExecuteKernelStreamLabel(); // calculate the leaf graph order of root graph std::vector> GetLeafGraphOrder(); - // update the child graph order of graph - void UpdateChildGraphOrder(); - // get the child graph of current graph - std::vector> child_graph_order() const { return child_graph_order_; } + // the child graph of current graph + const std::vector> &child_graph_order() const { return child_graph_order_; } + void set_child_graph_order(const std::vector> &order) { child_graph_order_ = order; } // checkout whether current graph is leaf graph bool IsLeafGraph() const; @@ -123,6 +122,7 @@ class KernelGraph : public FuncGraph { // find anf node in graph std::vector FindNodeByPrimitive(const PrimitivePtr &primitive) const; // get real inputs + const std::map> &real_inputs() const { return real_inputs_; } std::set GetRealInput(const AnfNodePtr ¶meter); void SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg); // used to dump ir @@ -132,6 +132,8 @@ class KernelGraph : public FuncGraph { void set_start_label(const CNodePtr &start_label) { start_label_ = start_label; } CNodePtr get_start_label() { return start_label_; } + void set_end_goto(const CNodePtr &end_goto) { end_goto_ = end_goto; } + CNodePtr get_end_goto() { return end_goto_; } private: // remove value node form graph @@ -185,6 +187,7 @@ class KernelGraph : public FuncGraph { std::map> real_inputs_; CNodePtr start_label_; + CNodePtr end_goto_; }; } // namespace session using KernelGraphPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc index db6257c815a..7cfe93dab06 100644 --- a/mindspore/ccsrc/session/session_basic.cc +++ b/mindspore/ccsrc/session/session_basic.cc @@ -147,6 +147,7 @@ BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph, MS_LOG(INFO) << "create tensor for output[" << anf->DebugString() << "]"; auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0); MS_EXCEPTION_IF_NULL(item_with_index.first); + MS_LOG(INFO) << "create tensor for output after visit:" << item_with_index.first->DebugString(); // special handle for maketuple if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) { auto cnode = item_with_index.first->cast(); @@ -479,31 +480,12 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) } for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) { - auto anf = cnode->inputs()[input_idx]; + auto anf = cnode->input(input_idx); MS_EXCEPTION_IF_NULL(anf); // anf has been created before if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); continue; - } else if (anf->isa()) { - if (!IsValueNode(anf)) { - // if input is a common value node, - auto new_value_node = CreateNewValueNode(anf, graph); - if (new_value_node != nullptr) { - cnode_inputs.emplace_back(new_value_node); - } - } else { - // if input is a ValueNode - auto new_value_node = CreateValueNodeKernelGraph(anf, graph); - if (new_value_node != nullptr) { - cnode_inputs.emplace_back(new_value_node); - } - } - continue; - } else if (anf->isa()) { - auto new_parameter = CreateNewParameter(anf, graph); - cnode_inputs.push_back(new_parameter); - continue; } MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]"; } @@ -613,32 +595,22 @@ std::shared_ptr SessionBasic::ConstructKernelGraph(const FuncGraphP for (const auto &node : node_list) { MS_EXCEPTION_IF_NULL(node); MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString(); - if (!node->isa()) { - MS_LOG(DEBUG) << "Node " << node->DebugString() << " is not CNode"; + if (node->isa()) { + (void)CreateNewParameter(node, graph.get()); + continue; + } else if (node->isa()) { + if (!IsValueNode(node)) { + // if input is a common value node, + (void)CreateNewValueNode(node, graph.get()); + } else { + // if input is a ValueNode + auto child_graph = ConstructKernelGraph(AnfAlgo::GetValueNodeFuncGraph(node)); + auto new_value_node = CreateValueNodeKernelGraph(node, graph.get()); + } continue; } else { auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); - - // recurse control ops: call, partial - auto attr_input = cnode->input(kAnfPrimitiveIndex); - MS_EXCEPTION_IF_NULL(attr_input); - if (IsValueNode(attr_input)) { - // recurse call subgraph - auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(attr_input); - ConstructKernelGraph(sub_func_graph); - } else if (IsValueNode(attr_input)) { - auto prim = GetCNodePrimitive(node); - MS_EXCEPTION_IF_NULL(prim); - if (prim->name() == kPartialOpName) { - // recurse partial subgraph - auto func_graph_node = cnode->input(kAnfPartialFuncGraphIndex); - MS_EXCEPTION_IF_NULL(func_graph_node); - auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(func_graph_node); - ConstructKernelGraph(sub_func_graph); - } - } - // create a new cnode object auto new_cnode = CreateNewCNode(cnode, graph.get()); MS_EXCEPTION_IF_NULL(new_cnode); @@ -650,7 +622,21 @@ std::shared_ptr SessionBasic::ConstructKernelGraph(const FuncGraphP } } } - + auto graph_inputs = graph->MutableInputs(); + MS_EXCEPTION_IF_NULL(graph_inputs); + graph_inputs->clear(); + for (auto ¶meter : func_graph->parameters()) { + MS_EXCEPTION_IF_NULL(parameter); + auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter); + if (backend_parameter == nullptr) { + // for example "def f(x,y,z) {return x + y}", parameter z in unused + CreateNewParameterFromParameter(parameter, false, graph.get()); + MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString(); + continue; + } + MS_LOG(INFO) << "graph[" << graph->graph_id() << "],parameter:" << parameter->DebugString(); + graph_inputs->push_back(backend_parameter); + } MS_EXCEPTION_IF_NULL(context_); FuncGraphManagerPtr manager = context_->manager(); if (manager) { @@ -716,6 +702,11 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr &kernel_grap const std::vector &input_tensors) const { MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(outputs); + if (!kernel_graph->child_graph_order().empty()) { + // use the last child graph output as the root graph output + UpdateOutputs(kernel_graph->child_graph_order().back(), outputs, input_tensors); + return; + } auto anf_outputs = kernel_graph->outputs(); for (auto &item : anf_outputs) { MS_LOG(INFO) << "update output[" << item->DebugString() << "]"; diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index 48db921f9f6..93d5f33cf43 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -487,8 +487,7 @@ void CompileGraph::AddExternal(const LinConvertResult &result) { } void TraverseGraphMap( - const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, - const FuncGraphSet &fgs, + const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, const FuncGraphSet &fgs, const std::function(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) { MS_EXCEPTION_IF_NULL(manager_ptr); MS_EXCEPTION_IF_NULL(tr);