Link control arrow from kernel graph exit actor to entrance actor.

This commit is contained in:
gaoyong10 2022-01-20 20:19:43 +08:00
parent 2789fceade
commit d3770b5edf
4 changed files with 110 additions and 30 deletions

View File

@ -795,7 +795,7 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons
FetchAutoMonadNode(control_nodes); FetchAutoMonadNode(control_nodes);
ParseFirstControlNodeForFuncGraph(control_nodes); ParseFirstControlNodeAndKernelGraphForFuncGraph(control_nodes);
} }
bool ControlNodeParser::IsControlFlowDataArrow(const KernelGraphPtr &graph, const AnfNodePtr &backend_node) { bool ControlNodeParser::IsControlFlowDataArrow(const KernelGraphPtr &graph, const AnfNodePtr &backend_node) {
@ -1671,7 +1671,7 @@ AnfNodePtr ControlNodeParser::FetchRootGraphFrontNodeBySubFrontNode(const AnfNod
return sub_front_node_to_root_front_node_[sub_front_node]; return sub_front_node_to_root_front_node_[sub_front_node];
} }
void ControlNodeParser::ParseFirstControlNodeForFuncGraph(const std::vector<AnfNodePtr> &control_nodes) { void ControlNodeParser::ParseFirstControlNodeAndKernelGraphForFuncGraph(const std::vector<AnfNodePtr> &control_nodes) {
for (const auto &control_node : control_nodes) { for (const auto &control_node : control_nodes) {
std::set<AnfNodePtr> checked_nodes; std::set<AnfNodePtr> checked_nodes;
if (((AnfAlgo::IsCallNode(control_node) && if (((AnfAlgo::IsCallNode(control_node) &&
@ -1681,6 +1681,50 @@ void ControlNodeParser::ParseFirstControlNodeForFuncGraph(const std::vector<AnfN
const auto &func_graph = control_node->func_graph(); const auto &func_graph = control_node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
(void)func_graph_to_first_control_nodes_[func_graph].emplace(control_node); (void)func_graph_to_first_control_nodes_[func_graph].emplace(control_node);
if (!AnfAlgo::IsCallNode(control_node)) {
continue;
}
// If there is a recursive call node in the funcgraph, the kernel graph of the topo sort before the call node
// needs to be executed before the call recursion, that is, the kernel graph whose level is less than the call
// node needs to link a control arrow to the corresponding entry actor.
// Fetch the level of control node.
const auto &level_iter = node_to_level_.find(control_node);
if (level_iter == node_to_level_.end()) {
MS_LOG(WARNING) << "Failed to get level for call node:" << control_node->DebugString();
continue;
}
// Fetch all of the kernel graph group info whose level less than the control node.
const auto &graph_group_iter = func_graph_to_kernel_graph_groups_.find(func_graph);
if (graph_group_iter == func_graph_to_kernel_graph_groups_.end()) {
continue;
}
for (const auto &kernel_graphs : graph_group_iter->second) {
// Fetch one graph from the group.
KernelGraphPtr dst_graph = nullptr;
for (const auto &graph : kernel_graphs) {
MS_EXCEPTION_IF_NULL(graph);
if (graph->execution_order().empty()) {
continue;
}
dst_graph = graph;
break;
}
if (dst_graph == nullptr) {
continue;
}
// Fetch the group info.
const auto &group_info_iter = kernel_graphs_to_group_info_.find(dst_graph);
if (group_info_iter == kernel_graphs_to_group_info_.end()) {
MS_LOG(EXCEPTION) << "Failed to get group info for kernel_graph:" << dst_graph->ToString();
}
if (group_info_iter->second->level_ < level_iter->second) {
func_graph_to_first_kernel_graphs_[func_graph].emplace(group_info_iter->second);
}
}
} }
} }
} }

View File

@ -198,7 +198,7 @@ class ControlNodeParser {
// the entrance actor so that it can process next parameters. This is used to obtain the nodes corresponding to all // the entrance actor so that it can process next parameters. This is used to obtain the nodes corresponding to all
// actors in the funcgraph that need to send control messages to the entrance. // actors in the funcgraph that need to send control messages to the entrance.
// These node are control nodes without control node input in the topological sort of the funcgraph. // These node are control nodes without control node input in the topological sort of the funcgraph.
void ParseFirstControlNodeForFuncGraph(const std::vector<AnfNodePtr> &control_nodes); void ParseFirstControlNodeAndKernelGraphForFuncGraph(const std::vector<AnfNodePtr> &control_nodes);
// Parse all funcgraphs that call nodes may call. // Parse all funcgraphs that call nodes may call.
void ParseCallNodeToFuncGraph(const std::vector<AnfNodePtr> &control_nodes); void ParseCallNodeToFuncGraph(const std::vector<AnfNodePtr> &control_nodes);
@ -273,6 +273,11 @@ class ControlNodeParser {
mindspore::HashMap<AnfNodePtr, AnfNodePtr> kernel_to_call_nodes_; mindspore::HashMap<AnfNodePtr, AnfNodePtr> kernel_to_call_nodes_;
// Control nodes without a control node input in the topological sorting of funcgraph. // Control nodes without a control node input in the topological sorting of funcgraph.
mindspore::HashMap<FuncGraphPtr, std::set<AnfNodePtr>> func_graph_to_first_control_nodes_; mindspore::HashMap<FuncGraphPtr, std::set<AnfNodePtr>> func_graph_to_first_control_nodes_;
// Kernel graphs need to link a control arrow to its entrance actor.
// In the recursive scene, some kernel graph needs to be completed before the next set of data is sent by the
// entrance actor. At this time, it is necessary to connect a control arrow from the exit actor of the graph
// to the entrance actor.
mindspore::HashMap<FuncGraphPtr, std::set<KernelGraphGroupInfoPtr>> func_graph_to_first_kernel_graphs_;
// Call nodes without recursive call. The funcgraphs of the call will not call the funcgraph where the call node // Call nodes without recursive call. The funcgraphs of the call will not call the funcgraph where the call node
// belong. // belong.
std::set<AnfNodePtr> unrecursion_call_nodes_; std::set<AnfNodePtr> unrecursion_call_nodes_;

View File

@ -829,33 +829,7 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor
MS_EXCEPTION_IF_NULL(control_actor_set); MS_EXCEPTION_IF_NULL(control_actor_set);
const auto &parser = graph_compiler_info.control_node_parser_; const auto &parser = graph_compiler_info.control_node_parser_;
MS_EXCEPTION_IF_NULL(parser); MS_EXCEPTION_IF_NULL(parser);
LinkControlArrowForEntranceActor(actor_set, graph_compiler_info);
// Since only one set of real parameters are allowed to be executed in funcgraph at the same time, when the funcgraph
// stops running, it is necessary to send the control arrow to the corresponding entrance actor at the exit of the
// graph to run the next set of real parameters. The corresponding nodes of the actors that need to send the control
// arrow have been parsed in the control node parser.
for (const auto &graph_to_nodes : parser->func_graph_to_first_control_nodes_) {
// Fetch the entrance actor.
const auto &func_graph = graph_to_nodes.first;
MS_EXCEPTION_IF_NULL(func_graph);
auto actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
auto entrance_actor = dynamic_cast<EntranceActor *>(FetchActor(actor_name));
MS_EXCEPTION_IF_NULL(entrance_actor);
const auto &nodes = graph_to_nodes.second;
for (const auto &node : nodes) {
// Fetch the source actor of control arrow.
MS_EXCEPTION_IF_NULL(node);
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
actor_name = func_graph->ToString() + kExitActorNameSuffix;
} else {
actor_name = GetActorName(node);
}
auto from_actor = dynamic_cast<ControlActor *>(FetchActor(actor_name));
MS_EXCEPTION_IF_NULL(from_actor);
LinkLoopBodyControlArrow(from_actor, entrance_actor);
}
}
// When the switch actor and gather actor have no input, need to link a control arrow from entrance actor. // When the switch actor and gather actor have no input, need to link a control arrow from entrance actor.
std::vector<ControlActor *> need_check_control_actors; std::vector<ControlActor *> need_check_control_actors;
@ -939,6 +913,62 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor
} }
} }
void ControlNodeScheduler::LinkControlArrowForEntranceActor(ActorSet *const actor_set,
const GraphCompilerInfo &graph_compiler_info) {
MS_EXCEPTION_IF_NULL(actor_set);
auto control_actor_set = actor_set->control_actors_.get();
MS_EXCEPTION_IF_NULL(control_actor_set);
const auto &parser = graph_compiler_info.control_node_parser_;
MS_EXCEPTION_IF_NULL(parser);
// Since only one set of real parameters are allowed to be executed in funcgraph at the same time, when the funcgraph
// stops running, it is necessary to send the control arrow to the corresponding entrance actor at the exit of the
// graph to run the next set of real parameters. The corresponding nodes of the actors that need to send the control
// arrow have been parsed in the control node parser.
for (const auto &graph_to_nodes : parser->func_graph_to_first_control_nodes_) {
// Fetch the entrance actor.
const auto &func_graph = graph_to_nodes.first;
MS_EXCEPTION_IF_NULL(func_graph);
auto actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
auto entrance_actor = dynamic_cast<EntranceActor *>(FetchActor(actor_name));
MS_EXCEPTION_IF_NULL(entrance_actor);
const auto &nodes = graph_to_nodes.second;
for (const auto &node : nodes) {
// Fetch the source actor of control arrow.
MS_EXCEPTION_IF_NULL(node);
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
actor_name = func_graph->ToString() + kExitActorNameSuffix;
} else {
actor_name = GetActorName(node);
}
auto from_actor = dynamic_cast<ControlActor *>(FetchActor(actor_name));
MS_EXCEPTION_IF_NULL(from_actor);
LinkLoopBodyControlArrow(from_actor, entrance_actor);
}
}
// In the recursive scene, some kernel graph needs to be completed before the next set of data is sent by the
// entrance actor. At this time, it is necessary to connect a control arrow from the exit actor of the graph
// to the entrance actor.
for (const auto &func_graph_to_group_info : parser->func_graph_to_first_kernel_graphs_) {
const auto &func_graph = func_graph_to_group_info.first;
MS_EXCEPTION_IF_NULL(func_graph);
auto actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
auto actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(actor);
auto entrance_actor = dynamic_cast<EntranceActor *>(actor);
MS_EXCEPTION_IF_NULL(entrance_actor);
for (const auto &group_info : func_graph_to_group_info.second) {
MS_EXCEPTION_IF_NULL(group_info);
actor_name = group_info->group_name_ + kExitActorNameSuffix;
auto from_actor = FetchActor(actor_name);
MS_EXCEPTION_IF_NULL(from_actor);
LinkLoopBodyControlArrow(from_actor, entrance_actor);
}
}
}
void ControlNodeScheduler::LinkControlArrowForLoopCountActor(const ActorSet *actor_set, void ControlNodeScheduler::LinkControlArrowForLoopCountActor(const ActorSet *actor_set,
const GraphCompilerInfo &graph_compiler_info) { const GraphCompilerInfo &graph_compiler_info) {
MS_EXCEPTION_IF_NULL(actor_set); MS_EXCEPTION_IF_NULL(actor_set);

View File

@ -58,6 +58,7 @@ class ControlNodeScheduler {
std::vector<StackActorPtr> *const stack_actors); std::vector<StackActorPtr> *const stack_actors);
// Interface to link control actors. // Interface to link control actors.
void LinkControlArrowForControlActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info); void LinkControlArrowForControlActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
void LinkControlArrowForEntranceActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
void LinkBranchIDArrowForControlActor(ControlActorSet *const control_actor_set); void LinkBranchIDArrowForControlActor(ControlActorSet *const control_actor_set);
// Link all arrows between control actors. // Link all arrows between control actors.
void LinkArrowForControlActor(ControlActorSet *const control_actor_set, const GraphCompilerInfo &graph_compiler_info); void LinkArrowForControlActor(ControlActorSet *const control_actor_set, const GraphCompilerInfo &graph_compiler_info);