From d3770b5edf160c4da2d07e9fd812be46498bb1c4 Mon Sep 17 00:00:00 2001 From: gaoyong10 Date: Thu, 20 Jan 2022 20:19:43 +0800 Subject: [PATCH] Link control arrow from kernel graph exit actor to entrance actor. --- .../runtime/framework/control_node_parser.cc | 48 ++++++++++- .../runtime/framework/control_node_parser.h | 7 +- .../framework/control_node_scheduler.cc | 84 +++++++++++++------ .../framework/control_node_scheduler.h | 1 + 4 files changed, 110 insertions(+), 30 deletions(-) diff --git a/mindspore/ccsrc/runtime/framework/control_node_parser.cc b/mindspore/ccsrc/runtime/framework/control_node_parser.cc index 2dac7c7c3d1..b4d0a6b1043 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_parser.cc +++ b/mindspore/ccsrc/runtime/framework/control_node_parser.cc @@ -795,7 +795,7 @@ void ControlNodeParser::Parse(const std::vector &control_nodes, cons FetchAutoMonadNode(control_nodes); - ParseFirstControlNodeForFuncGraph(control_nodes); + ParseFirstControlNodeAndKernelGraphForFuncGraph(control_nodes); } bool ControlNodeParser::IsControlFlowDataArrow(const KernelGraphPtr &graph, const AnfNodePtr &backend_node) { @@ -1671,7 +1671,7 @@ AnfNodePtr ControlNodeParser::FetchRootGraphFrontNodeBySubFrontNode(const AnfNod return sub_front_node_to_root_front_node_[sub_front_node]; } -void ControlNodeParser::ParseFirstControlNodeForFuncGraph(const std::vector &control_nodes) { +void ControlNodeParser::ParseFirstControlNodeAndKernelGraphForFuncGraph(const std::vector &control_nodes) { for (const auto &control_node : control_nodes) { std::set checked_nodes; if (((AnfAlgo::IsCallNode(control_node) && @@ -1681,6 +1681,50 @@ void ControlNodeParser::ParseFirstControlNodeForFuncGraph(const std::vectorfunc_graph(); MS_EXCEPTION_IF_NULL(func_graph); (void)func_graph_to_first_control_nodes_[func_graph].emplace(control_node); + + if (!AnfAlgo::IsCallNode(control_node)) { + continue; + } + + // If there is a recursive call node in the funcgraph, the kernel graph of the topo sort before the call node + // needs to be executed before the call recursion, that is, the kernel graph whose level is less than the call + // node needs to link a control arrow to the corresponding entry actor. + // Fetch the level of control node. + const auto &level_iter = node_to_level_.find(control_node); + if (level_iter == node_to_level_.end()) { + MS_LOG(WARNING) << "Failed to get level for call node:" << control_node->DebugString(); + continue; + } + + // Fetch all of the kernel graph group info whose level less than the control node. + const auto &graph_group_iter = func_graph_to_kernel_graph_groups_.find(func_graph); + if (graph_group_iter == func_graph_to_kernel_graph_groups_.end()) { + continue; + } + for (const auto &kernel_graphs : graph_group_iter->second) { + // Fetch one graph from the group. + KernelGraphPtr dst_graph = nullptr; + for (const auto &graph : kernel_graphs) { + MS_EXCEPTION_IF_NULL(graph); + if (graph->execution_order().empty()) { + continue; + } + dst_graph = graph; + break; + } + if (dst_graph == nullptr) { + continue; + } + + // Fetch the group info. + const auto &group_info_iter = kernel_graphs_to_group_info_.find(dst_graph); + if (group_info_iter == kernel_graphs_to_group_info_.end()) { + MS_LOG(EXCEPTION) << "Failed to get group info for kernel_graph:" << dst_graph->ToString(); + } + if (group_info_iter->second->level_ < level_iter->second) { + func_graph_to_first_kernel_graphs_[func_graph].emplace(group_info_iter->second); + } + } } } } diff --git a/mindspore/ccsrc/runtime/framework/control_node_parser.h b/mindspore/ccsrc/runtime/framework/control_node_parser.h index 4b60da987c5..87bc4d4ee8d 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_parser.h +++ b/mindspore/ccsrc/runtime/framework/control_node_parser.h @@ -198,7 +198,7 @@ class ControlNodeParser { // the entrance actor so that it can process next parameters. This is used to obtain the nodes corresponding to all // actors in the funcgraph that need to send control messages to the entrance. // These node are control nodes without control node input in the topological sort of the funcgraph. - void ParseFirstControlNodeForFuncGraph(const std::vector &control_nodes); + void ParseFirstControlNodeAndKernelGraphForFuncGraph(const std::vector &control_nodes); // Parse all funcgraphs that call nodes may call. void ParseCallNodeToFuncGraph(const std::vector &control_nodes); @@ -273,6 +273,11 @@ class ControlNodeParser { mindspore::HashMap kernel_to_call_nodes_; // Control nodes without a control node input in the topological sorting of funcgraph. mindspore::HashMap> func_graph_to_first_control_nodes_; + // Kernel graphs need to link a control arrow to its entrance actor. + // In the recursive scene, some kernel graph needs to be completed before the next set of data is sent by the + // entrance actor. At this time, it is necessary to connect a control arrow from the exit actor of the graph + // to the entrance actor. + mindspore::HashMap> func_graph_to_first_kernel_graphs_; // Call nodes without recursive call. The funcgraphs of the call will not call the funcgraph where the call node // belong. std::set unrecursion_call_nodes_; diff --git a/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc b/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc index 7a0ac679379..dbf42221f91 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc +++ b/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc @@ -829,33 +829,7 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor MS_EXCEPTION_IF_NULL(control_actor_set); const auto &parser = graph_compiler_info.control_node_parser_; MS_EXCEPTION_IF_NULL(parser); - - // Since only one set of real parameters are allowed to be executed in funcgraph at the same time, when the funcgraph - // stops running, it is necessary to send the control arrow to the corresponding entrance actor at the exit of the - // graph to run the next set of real parameters. The corresponding nodes of the actors that need to send the control - // arrow have been parsed in the control node parser. - for (const auto &graph_to_nodes : parser->func_graph_to_first_control_nodes_) { - // Fetch the entrance actor. - const auto &func_graph = graph_to_nodes.first; - MS_EXCEPTION_IF_NULL(func_graph); - auto actor_name = func_graph->ToString() + kEntranceActorNameSuffix; - auto entrance_actor = dynamic_cast(FetchActor(actor_name)); - MS_EXCEPTION_IF_NULL(entrance_actor); - - const auto &nodes = graph_to_nodes.second; - for (const auto &node : nodes) { - // Fetch the source actor of control arrow. - MS_EXCEPTION_IF_NULL(node); - if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { - actor_name = func_graph->ToString() + kExitActorNameSuffix; - } else { - actor_name = GetActorName(node); - } - auto from_actor = dynamic_cast(FetchActor(actor_name)); - MS_EXCEPTION_IF_NULL(from_actor); - LinkLoopBodyControlArrow(from_actor, entrance_actor); - } - } + LinkControlArrowForEntranceActor(actor_set, graph_compiler_info); // When the switch actor and gather actor have no input, need to link a control arrow from entrance actor. std::vector need_check_control_actors; @@ -939,6 +913,62 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor } } +void ControlNodeScheduler::LinkControlArrowForEntranceActor(ActorSet *const actor_set, + const GraphCompilerInfo &graph_compiler_info) { + MS_EXCEPTION_IF_NULL(actor_set); + auto control_actor_set = actor_set->control_actors_.get(); + MS_EXCEPTION_IF_NULL(control_actor_set); + const auto &parser = graph_compiler_info.control_node_parser_; + MS_EXCEPTION_IF_NULL(parser); + + // Since only one set of real parameters are allowed to be executed in funcgraph at the same time, when the funcgraph + // stops running, it is necessary to send the control arrow to the corresponding entrance actor at the exit of the + // graph to run the next set of real parameters. The corresponding nodes of the actors that need to send the control + // arrow have been parsed in the control node parser. + for (const auto &graph_to_nodes : parser->func_graph_to_first_control_nodes_) { + // Fetch the entrance actor. + const auto &func_graph = graph_to_nodes.first; + MS_EXCEPTION_IF_NULL(func_graph); + auto actor_name = func_graph->ToString() + kEntranceActorNameSuffix; + auto entrance_actor = dynamic_cast(FetchActor(actor_name)); + MS_EXCEPTION_IF_NULL(entrance_actor); + + const auto &nodes = graph_to_nodes.second; + for (const auto &node : nodes) { + // Fetch the source actor of control arrow. + MS_EXCEPTION_IF_NULL(node); + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + actor_name = func_graph->ToString() + kExitActorNameSuffix; + } else { + actor_name = GetActorName(node); + } + auto from_actor = dynamic_cast(FetchActor(actor_name)); + MS_EXCEPTION_IF_NULL(from_actor); + LinkLoopBodyControlArrow(from_actor, entrance_actor); + } + } + + // In the recursive scene, some kernel graph needs to be completed before the next set of data is sent by the + // entrance actor. At this time, it is necessary to connect a control arrow from the exit actor of the graph + // to the entrance actor. + for (const auto &func_graph_to_group_info : parser->func_graph_to_first_kernel_graphs_) { + const auto &func_graph = func_graph_to_group_info.first; + MS_EXCEPTION_IF_NULL(func_graph); + auto actor_name = func_graph->ToString() + kEntranceActorNameSuffix; + auto actor = FetchActor(actor_name); + MS_EXCEPTION_IF_NULL(actor); + auto entrance_actor = dynamic_cast(actor); + MS_EXCEPTION_IF_NULL(entrance_actor); + for (const auto &group_info : func_graph_to_group_info.second) { + MS_EXCEPTION_IF_NULL(group_info); + actor_name = group_info->group_name_ + kExitActorNameSuffix; + auto from_actor = FetchActor(actor_name); + MS_EXCEPTION_IF_NULL(from_actor); + LinkLoopBodyControlArrow(from_actor, entrance_actor); + } + } +} + void ControlNodeScheduler::LinkControlArrowForLoopCountActor(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) { MS_EXCEPTION_IF_NULL(actor_set); diff --git a/mindspore/ccsrc/runtime/framework/control_node_scheduler.h b/mindspore/ccsrc/runtime/framework/control_node_scheduler.h index c6550fb548e..a2e7588da72 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_scheduler.h +++ b/mindspore/ccsrc/runtime/framework/control_node_scheduler.h @@ -58,6 +58,7 @@ class ControlNodeScheduler { std::vector *const stack_actors); // Interface to link control actors. void LinkControlArrowForControlActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info); + void LinkControlArrowForEntranceActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info); void LinkBranchIDArrowForControlActor(ControlActorSet *const control_actor_set); // Link all arrows between control actors. void LinkArrowForControlActor(ControlActorSet *const control_actor_set, const GraphCompilerInfo &graph_compiler_info);