From 434a6aa0db754315124c43ef0b16686bd95ffd3f Mon Sep 17 00:00:00 2001 From: zhoufeng Date: Mon, 8 Jun 2020 15:53:10 +0800 Subject: [PATCH] new control sink support resnet50 Signed-off-by: zhoufeng --- .../device/ascend/ascend_label_assign.cc | 16 +- .../ccsrc/session/ascend_control_parser.cc | 160 ++++++++---------- .../ccsrc/session/ascend_control_parser.h | 3 - mindspore/ccsrc/session/kernel_graph.cc | 4 +- mindspore/ccsrc/session/session_basic.cc | 6 - 5 files changed, 76 insertions(+), 113 deletions(-) diff --git a/mindspore/ccsrc/device/ascend/ascend_label_assign.cc b/mindspore/ccsrc/device/ascend/ascend_label_assign.cc index 9908b5d03d8..39e9feb73d7 100644 --- a/mindspore/ccsrc/device/ascend/ascend_label_assign.cc +++ b/mindspore/ccsrc/device/ascend/ascend_label_assign.cc @@ -33,11 +33,9 @@ static void UpdateLabelGoto(NotNull node) { if (node->size() <= kLabelGotoLabelId) { MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " has invalid input size " << node->size(); } - auto label_set = AnfAlgo::GetCNodePrimitive(node->input(kLabelGotoLabelId)); - MS_EXCEPTION_IF_NULL(label_set); - auto value = label_set->GetAttr(kAttrLabelIndex); - MS_EXCEPTION_IF_NULL(value); - uint32_t goto_label_id = GetValue(value); + + auto input = node->input(kLabelGotoLabelId); + uint32_t goto_label_id = AnfAlgo::GetNodeAttr(input, kAttrLabelIndex); AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(goto_label_id), node.get()); MS_LOG(INFO) << "Node " << node->DebugString() << " goto label id " << goto_label_id; node->set_inputs({node->input(0)}); @@ -57,11 +55,7 @@ static void UpdateLabelSwitch(NotNull node) { break; } - auto label_set = AnfAlgo::GetCNodePrimitive(input); - MS_EXCEPTION_IF_NULL(label_set); - auto value = label_set->GetAttr(kAttrLabelIndex); - MS_EXCEPTION_IF_NULL(value); - uint32_t goto_label_id = GetValue(value); + uint32_t goto_label_id = AnfAlgo::GetNodeAttr(input, kAttrLabelIndex); label_list.push_back(goto_label_id); MS_LOG(INFO) << "Switch " << node->DebugString() << " case " << i - kLabelSwitchLabelId << ": id " << goto_label_id; } @@ -154,7 +148,7 @@ uint32_t AscendLabelAssign::GetLabelNum(NotNull gr std::lock_guard lock(label_num_mutex_); auto iter = label_num_.find(graph.get()); if (iter == label_num_.end()) { - MS_LOG(WARNING) << "Graph " << graph->ToString() << " has not assigned label."; + MS_LOG(DEBUG) << "Graph " << graph->ToString() << " has not assigned label, defalut is 1."; return 1; } return iter->second; diff --git a/mindspore/ccsrc/session/ascend_control_parser.cc b/mindspore/ccsrc/session/ascend_control_parser.cc index 55dfbcbb37b..b0dc1cc523c 100644 --- a/mindspore/ccsrc/session/ascend_control_parser.cc +++ b/mindspore/ccsrc/session/ascend_control_parser.cc @@ -33,31 +33,6 @@ static constexpr size_t kCNodeSwitchLayerLength = 3; namespace mindspore { namespace session { -void AscendControlParser::ChildGraphDataAssign(const std::map &graph_id_map) { - for (auto &iter : graph_id_map) { - auto &kg = iter.second; - MS_EXCEPTION_IF_NULL(kg); - auto real_inputs = kg->real_inputs(); - for (auto &it : real_inputs) { - auto ¶meter = it.first; - auto &args = it.second; - for (auto &arg : args) { - MS_EXCEPTION_IF_NULL(arg); - if (arg->isa()) { - MS_LOG(INFO) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString() - << ", arg:" << arg->DebugString(); - continue; - } - auto target_graph_iter = graph_id_map.find(AnfAlgo::GetGraphId(arg.get())); - if (target_graph_iter == graph_id_map.end()) { - MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found."; - } - InsertAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(arg), NOT_NULL(parameter)); - } - } - } -} - static void InitUnionFindSet(NotNull kg, const NotNull *> union_find_set, const NotNull *> memo) { if (memo->find(kg.get()) != memo->end()) { @@ -89,6 +64,7 @@ static void UnionParentParameter(NotNull kg, const NotNullinsert(kg.get()); + const std::map> &real_inputs = kg->real_inputs(); for (auto &iter : real_inputs) { auto ¶ = iter.first; @@ -150,11 +126,10 @@ static void ReuseParameter(NotNull root_kg, NotNullinputs(); root_inputs_set.insert(root_inputs_vector.begin(), root_inputs_vector.end()); for (auto &node : parameter_reuse_set) { - if (root_inputs_set.find(node) == root_inputs_set.end()) { - continue; + if (root_inputs_set.find(node) != root_inputs_set.end()) { + main_parameter = node; + break; } - - main_parameter = node; } std::set memo; @@ -162,9 +137,18 @@ static void ReuseParameter(NotNull root_kg, NotNull &list, size_t start) { + for (size_t i = start; i < list.size() - 1; ++i) { + if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) { + return list[i]; + } + } + return nullptr; +} + void AscendControlParser::LinkGraph(NotNull kg) { std::set memo; - ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo)); + (void)ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo)); std::map graph_id_map; for (auto &g : memo) { if (graph_id_map.find(g->graph_id()) != graph_id_map.end()) { @@ -181,13 +165,34 @@ void AscendControlParser::LinkGraph(NotNull kg) { ChildGraphDataAssign(graph_id_map); } -CNodePtr AscendControlParser::GetNextRealKernel(const std::vector &list, size_t start) { - for (size_t i = start; i < list.size() - 1; ++i) { - if (!IsPrimitiveCNode(list[i], prim::kPrimPartial) && AnfAlgo::IsRealKernel(list[i])) { - return list[i]; +void AscendControlParser::ExecutorValidate(NotNull root_graph) { + std::set memo; + (void)RecurseGraph(root_graph, NOT_NULL(&memo)); +} + +void AscendControlParser::ChildGraphDataAssign(const std::map &graph_id_map) { + for (auto &iter : graph_id_map) { + auto &kg = iter.second; + MS_EXCEPTION_IF_NULL(kg); + auto real_inputs = kg->real_inputs(); + for (auto &it : real_inputs) { + auto ¶meter = it.first; + auto &args = it.second; + for (auto &arg : args) { + MS_EXCEPTION_IF_NULL(arg); + if (arg->isa()) { + MS_LOG(DEBUG) << "Parameter should be reused, no need insert assign, parameter: " << parameter->DebugString() + << ", arg:" << arg->DebugString(); + continue; + } + auto target_graph_iter = graph_id_map.find(AnfAlgo::GetGraphId(arg.get())); + if (target_graph_iter == graph_id_map.end()) { + MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found."; + } + InsertAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(arg), NOT_NULL(parameter)); + } } } - return nullptr; } NotNull AscendControlParser::ProcessKernelGraph(NotNull kg, const CNodePtr &last_node, @@ -212,9 +217,16 @@ NotNull AscendControlParser::ProcessKernelGraph(NotNullToString() << " has no cnodes!"; } // 4. insert first_label - auto start_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSetOpName))}); - MS_LOG(INFO) << "Insert start label " << start_label->DebugString() << " to " << kg->ToString(); - kg->set_start_label(start_label); + CNodePtr start_label; + if (last_node != nullptr && last_label != nullptr) { + start_label = kg->NewCNode({std::make_shared(std::make_shared(kLabelSetOpName))}); + MS_LOG(INFO) << "Insert start label " << start_label->DebugString() << " to " << kg->ToString(); + kg->set_start_label(start_label); + } else { + // no goto node will jump to start label of root graph, so return a fake label + start_label = std::make_shared(std::vector(), FuncGraphPtr(nullptr)); + } + // 5. traverse for (size_t i = 0; i < nodes.size(); ++i) { auto &cnode = nodes[i]; @@ -249,11 +261,10 @@ NotNull AscendControlParser::ProcessKernelGraph(NotNull kg, NotNull attch_node) { - std::vector inputs = {NewValueNode(std::make_shared("depend"))}; auto return_node = kg->get_return(); MS_EXCEPTION_IF_NULL(return_node); - inputs.push_back(return_node->input(1)); - inputs.push_back(attch_node.get()); + std::vector inputs = {NewValueNode(std::make_shared(prim::kPrimDepend->name())), + return_node->input(1), attch_node.get()}; auto depend_node = kg->NewCNode(inputs); return_node->set_input(1, depend_node); } @@ -407,9 +418,9 @@ std::tuple AscendControlParser::ParsePartial(NotNullsize() < kCNodePartialLength) { MS_LOG(EXCEPTION) << "Inputs of partial node must more than " << kCNodePartialLength; } + auto partial_inputs = partial_cnode->inputs(); auto branch_kg = GetValueNode(partial_inputs[kCNodePartialFunc]); - return {partial_cnode, branch_kg}; } @@ -425,7 +436,7 @@ void AscendControlParser::InsertAssignToGraph(NotNull kg, NotNul MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to " << to->DebugString(); // config inputs of assign node - std::vector inputs = {NewValueNode(std::make_shared("Assign")), to, from}; + std::vector inputs = {NewValueNode(std::make_shared(prim::kPrimAssign->name())), to, from}; // generate a new cnode auto assign_node = kg->NewCNode(inputs); MS_EXCEPTION_IF_NULL(assign_node); @@ -434,11 +445,6 @@ void AscendControlParser::InsertAssignToGraph(NotNull kg, NotNul InsertDependToGraph(kg, NOT_NULL(assign_node)); } -void AscendControlParser::ExecutorValidate(NotNull root_graph) { - std::set memo; - (void)RecurseGraph(root_graph, NOT_NULL(&memo)); -} - std::vector AscendControlParser::RecurseGraph(NotNull graph, const NotNull *> memo) { MS_LOG(INFO) << "graph:" << graph->graph_id() << " start"; @@ -457,29 +463,24 @@ std::vector AscendControlParser::RecurseGraph(NotNull if (node == graph->get_end_goto()) { continue; } - if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { - if (!CheckLabelIndex(child_order_index, 0, node, graph)) { - MS_LOG(EXCEPTION) << "Check label index fail"; - } - auto child_graph = graph->child_graph_order()[child_order_index++]; - if (child_graph == graph->parent_graph()) { - continue; - } - auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); - execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); - } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { - std::vector label_switch_list = GetLabelSwitchList(node); + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) { + std::vector label_switch_list = AnfAlgo::GetNodeAttr>(node, kAttrLabelSwitchList); for (auto iter = label_switch_list.rbegin(); iter != label_switch_list.rend(); ++iter) { if (!CheckLabelIndex(child_order_index, *iter, node, graph)) { MS_LOG(EXCEPTION) << "Check label index fail"; } auto child_graph = graph->child_graph_order()[child_order_index++]; - if (child_graph == graph->parent_graph()) { - continue; - } auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); } + } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) { + uint32_t label_index = AnfAlgo::GetNodeAttr(node, kAttrLabelIndex); + if (!CheckLabelIndex(child_order_index, label_index, node, graph)) { + MS_LOG(EXCEPTION) << "Check label index fail"; + } + auto child_graph = graph->child_graph_order()[child_order_index++]; + auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo); + execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end()); } } graph->set_execution_order(execution_order); @@ -487,15 +488,6 @@ std::vector AscendControlParser::RecurseGraph(NotNull return execution_order; } -std::vector AscendControlParser::GetLabelSwitchList(const CNodePtr &node) { - if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, node)) { - MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list"; - } - auto primitive = AnfAlgo::GetCNodePrimitive(node); - MS_EXCEPTION_IF_NULL(primitive); - return GetValue>(primitive->GetAttr(kAttrLabelSwitchList)); -} - bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label, NotNull graph) { const std::vector> &child_graph_order = graph->child_graph_order(); @@ -504,33 +496,19 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i MS_LOG(EXCEPTION) << "Child graph order is wrong, graph " << graph->ToString() << " child graph size " << child_graph_order.size() << " goto index " << order_index; } - - if (AnfAlgo::CheckPrimitiveType(cur_label, prim::kPrimLabelGoto)) { - // check label_goto and start_label in child graph - if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, cur_label)) { - MS_LOG(EXCEPTION) << "LabelSetKernel has no attr label_index"; - } - auto primitive = AnfAlgo::GetCNodePrimitive(cur_label); - MS_EXCEPTION_IF_NULL(primitive); - uint32_t label_goto_index = GetValue(primitive->GetAttr(kAttrLabelIndex)); - label_index = label_goto_index; - } - // get start_label_set_index of child graph auto child_graph = child_graph_order[order_index]; MS_EXCEPTION_IF_NULL(child_graph); + + // get start_label_set_index of child graph auto start_label_set = child_graph->get_start_label(); - if (!AnfAlgo::HasNodeAttr(kAttrLabelIndex, start_label_set)) { - MS_LOG(EXCEPTION) << "LabelSetKernel has no attr label_index"; - } - auto start_primitive = AnfAlgo::GetCNodePrimitive(start_label_set); - MS_EXCEPTION_IF_NULL(start_primitive); - uint32_t start_label_set_index = GetValue(start_primitive->GetAttr(kAttrLabelIndex)); + uint32_t start_label_set_index = AnfAlgo::GetNodeAttr(start_label_set, kAttrLabelIndex); if (label_index != start_label_set_index) { MS_LOG(WARNING) << cur_label->DebugString() << " index " << label_index << " but " << start_label_set->DebugString() << " index " << start_label_set_index << " current child graph order : " << order_index; return false; + } else { + return true; } - return true; } void AscendControlParser::UpdateChildGraphOrder(NotNull kg) { diff --git a/mindspore/ccsrc/session/ascend_control_parser.h b/mindspore/ccsrc/session/ascend_control_parser.h index cee3816a6e5..05f5e197294 100644 --- a/mindspore/ccsrc/session/ascend_control_parser.h +++ b/mindspore/ccsrc/session/ascend_control_parser.h @@ -54,10 +54,7 @@ class AscendControlParser { static void InsertAssignToGraph(NotNull kg, NotNull from, NotNull to); - static CNodePtr GetNextRealKernel(const std::vector &list, size_t start); - // root graph order - static std::vector GetLabelSwitchList(const CNodePtr &node); static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode, NotNull graph); static std::vector RecurseGraph(NotNull graph, diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc index 8fa29ae20ff..9adf3ca97b9 100644 --- a/mindspore/ccsrc/session/kernel_graph.cc +++ b/mindspore/ccsrc/session/kernel_graph.cc @@ -377,8 +377,8 @@ void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, cons MS_EXCEPTION_IF_NULL(old_backend_anf); MS_EXCEPTION_IF_NULL(new_backend_anf); if (old_backend_anf == new_backend_anf) { - MS_LOG(INFO) << "old:" << old_backend_anf->DebugString() << ",new:" << new_backend_anf->DebugString(); - MS_LOG(EXCEPTION) << "old can't be same with new"; + MS_LOG(DEBUG) << "old same with new:" << old_backend_anf->DebugString(); + return; } if (backend_front_anf_map_.find(old_backend_anf) == backend_front_anf_map_.end()) { MS_LOG(DEBUG) << "old_backend_anf " << old_backend_anf->DebugString() << " is not exist in the map"; diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc index e36e0d4410d..93cfc6bbcdf 100644 --- a/mindspore/ccsrc/session/session_basic.cc +++ b/mindspore/ccsrc/session/session_basic.cc @@ -620,12 +620,6 @@ std::shared_ptr SessionBasic::ConstructKernelGraph(const FuncGraphP // if a graph jump back unconditionally, return op of this graph will never be executed, so output is null. graph->set_output_null(is_trace_back); AddParameterToGraphInputs(func_graph->parameters(), graph.get()); - MS_EXCEPTION_IF_NULL(context_); - FuncGraphManagerPtr manager = MakeManager({graph}); - if (manager) { - manager->AddFuncGraph(graph); - graph->set_manager(manager); - } graph->SetExecOrderByDefault(); if (ExistSummaryNode(graph.get())) { graph->set_summary_node_exist(true);