Link control arrow from kernel graph exit actor to entrance actor.
This commit is contained in:
parent
2789fceade
commit
d3770b5edf
|
@ -795,7 +795,7 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons
|
||||||
|
|
||||||
FetchAutoMonadNode(control_nodes);
|
FetchAutoMonadNode(control_nodes);
|
||||||
|
|
||||||
ParseFirstControlNodeForFuncGraph(control_nodes);
|
ParseFirstControlNodeAndKernelGraphForFuncGraph(control_nodes);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ControlNodeParser::IsControlFlowDataArrow(const KernelGraphPtr &graph, const AnfNodePtr &backend_node) {
|
bool ControlNodeParser::IsControlFlowDataArrow(const KernelGraphPtr &graph, const AnfNodePtr &backend_node) {
|
||||||
|
@ -1671,7 +1671,7 @@ AnfNodePtr ControlNodeParser::FetchRootGraphFrontNodeBySubFrontNode(const AnfNod
|
||||||
return sub_front_node_to_root_front_node_[sub_front_node];
|
return sub_front_node_to_root_front_node_[sub_front_node];
|
||||||
}
|
}
|
||||||
|
|
||||||
void ControlNodeParser::ParseFirstControlNodeForFuncGraph(const std::vector<AnfNodePtr> &control_nodes) {
|
void ControlNodeParser::ParseFirstControlNodeAndKernelGraphForFuncGraph(const std::vector<AnfNodePtr> &control_nodes) {
|
||||||
for (const auto &control_node : control_nodes) {
|
for (const auto &control_node : control_nodes) {
|
||||||
std::set<AnfNodePtr> checked_nodes;
|
std::set<AnfNodePtr> checked_nodes;
|
||||||
if (((AnfAlgo::IsCallNode(control_node) &&
|
if (((AnfAlgo::IsCallNode(control_node) &&
|
||||||
|
@ -1681,6 +1681,50 @@ void ControlNodeParser::ParseFirstControlNodeForFuncGraph(const std::vector<AnfN
|
||||||
const auto &func_graph = control_node->func_graph();
|
const auto &func_graph = control_node->func_graph();
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
(void)func_graph_to_first_control_nodes_[func_graph].emplace(control_node);
|
(void)func_graph_to_first_control_nodes_[func_graph].emplace(control_node);
|
||||||
|
|
||||||
|
if (!AnfAlgo::IsCallNode(control_node)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there is a recursive call node in the funcgraph, the kernel graph of the topo sort before the call node
|
||||||
|
// needs to be executed before the call recursion, that is, the kernel graph whose level is less than the call
|
||||||
|
// node needs to link a control arrow to the corresponding entry actor.
|
||||||
|
// Fetch the level of control node.
|
||||||
|
const auto &level_iter = node_to_level_.find(control_node);
|
||||||
|
if (level_iter == node_to_level_.end()) {
|
||||||
|
MS_LOG(WARNING) << "Failed to get level for call node:" << control_node->DebugString();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch all of the kernel graph group info whose level less than the control node.
|
||||||
|
const auto &graph_group_iter = func_graph_to_kernel_graph_groups_.find(func_graph);
|
||||||
|
if (graph_group_iter == func_graph_to_kernel_graph_groups_.end()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
for (const auto &kernel_graphs : graph_group_iter->second) {
|
||||||
|
// Fetch one graph from the group.
|
||||||
|
KernelGraphPtr dst_graph = nullptr;
|
||||||
|
for (const auto &graph : kernel_graphs) {
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
if (graph->execution_order().empty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
dst_graph = graph;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (dst_graph == nullptr) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch the group info.
|
||||||
|
const auto &group_info_iter = kernel_graphs_to_group_info_.find(dst_graph);
|
||||||
|
if (group_info_iter == kernel_graphs_to_group_info_.end()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Failed to get group info for kernel_graph:" << dst_graph->ToString();
|
||||||
|
}
|
||||||
|
if (group_info_iter->second->level_ < level_iter->second) {
|
||||||
|
func_graph_to_first_kernel_graphs_[func_graph].emplace(group_info_iter->second);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -198,7 +198,7 @@ class ControlNodeParser {
|
||||||
// the entrance actor so that it can process next parameters. This is used to obtain the nodes corresponding to all
|
// the entrance actor so that it can process next parameters. This is used to obtain the nodes corresponding to all
|
||||||
// actors in the funcgraph that need to send control messages to the entrance.
|
// actors in the funcgraph that need to send control messages to the entrance.
|
||||||
// These node are control nodes without control node input in the topological sort of the funcgraph.
|
// These node are control nodes without control node input in the topological sort of the funcgraph.
|
||||||
void ParseFirstControlNodeForFuncGraph(const std::vector<AnfNodePtr> &control_nodes);
|
void ParseFirstControlNodeAndKernelGraphForFuncGraph(const std::vector<AnfNodePtr> &control_nodes);
|
||||||
// Parse all funcgraphs that call nodes may call.
|
// Parse all funcgraphs that call nodes may call.
|
||||||
void ParseCallNodeToFuncGraph(const std::vector<AnfNodePtr> &control_nodes);
|
void ParseCallNodeToFuncGraph(const std::vector<AnfNodePtr> &control_nodes);
|
||||||
|
|
||||||
|
@ -273,6 +273,11 @@ class ControlNodeParser {
|
||||||
mindspore::HashMap<AnfNodePtr, AnfNodePtr> kernel_to_call_nodes_;
|
mindspore::HashMap<AnfNodePtr, AnfNodePtr> kernel_to_call_nodes_;
|
||||||
// Control nodes without a control node input in the topological sorting of funcgraph.
|
// Control nodes without a control node input in the topological sorting of funcgraph.
|
||||||
mindspore::HashMap<FuncGraphPtr, std::set<AnfNodePtr>> func_graph_to_first_control_nodes_;
|
mindspore::HashMap<FuncGraphPtr, std::set<AnfNodePtr>> func_graph_to_first_control_nodes_;
|
||||||
|
// Kernel graphs need to link a control arrow to its entrance actor.
|
||||||
|
// In the recursive scene, some kernel graph needs to be completed before the next set of data is sent by the
|
||||||
|
// entrance actor. At this time, it is necessary to connect a control arrow from the exit actor of the graph
|
||||||
|
// to the entrance actor.
|
||||||
|
mindspore::HashMap<FuncGraphPtr, std::set<KernelGraphGroupInfoPtr>> func_graph_to_first_kernel_graphs_;
|
||||||
// Call nodes without recursive call. The funcgraphs of the call will not call the funcgraph where the call node
|
// Call nodes without recursive call. The funcgraphs of the call will not call the funcgraph where the call node
|
||||||
// belong.
|
// belong.
|
||||||
std::set<AnfNodePtr> unrecursion_call_nodes_;
|
std::set<AnfNodePtr> unrecursion_call_nodes_;
|
||||||
|
|
|
@ -829,33 +829,7 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor
|
||||||
MS_EXCEPTION_IF_NULL(control_actor_set);
|
MS_EXCEPTION_IF_NULL(control_actor_set);
|
||||||
const auto &parser = graph_compiler_info.control_node_parser_;
|
const auto &parser = graph_compiler_info.control_node_parser_;
|
||||||
MS_EXCEPTION_IF_NULL(parser);
|
MS_EXCEPTION_IF_NULL(parser);
|
||||||
|
LinkControlArrowForEntranceActor(actor_set, graph_compiler_info);
|
||||||
// Since only one set of real parameters are allowed to be executed in funcgraph at the same time, when the funcgraph
|
|
||||||
// stops running, it is necessary to send the control arrow to the corresponding entrance actor at the exit of the
|
|
||||||
// graph to run the next set of real parameters. The corresponding nodes of the actors that need to send the control
|
|
||||||
// arrow have been parsed in the control node parser.
|
|
||||||
for (const auto &graph_to_nodes : parser->func_graph_to_first_control_nodes_) {
|
|
||||||
// Fetch the entrance actor.
|
|
||||||
const auto &func_graph = graph_to_nodes.first;
|
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
|
||||||
auto actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
|
|
||||||
auto entrance_actor = dynamic_cast<EntranceActor *>(FetchActor(actor_name));
|
|
||||||
MS_EXCEPTION_IF_NULL(entrance_actor);
|
|
||||||
|
|
||||||
const auto &nodes = graph_to_nodes.second;
|
|
||||||
for (const auto &node : nodes) {
|
|
||||||
// Fetch the source actor of control arrow.
|
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
|
||||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
|
|
||||||
actor_name = func_graph->ToString() + kExitActorNameSuffix;
|
|
||||||
} else {
|
|
||||||
actor_name = GetActorName(node);
|
|
||||||
}
|
|
||||||
auto from_actor = dynamic_cast<ControlActor *>(FetchActor(actor_name));
|
|
||||||
MS_EXCEPTION_IF_NULL(from_actor);
|
|
||||||
LinkLoopBodyControlArrow(from_actor, entrance_actor);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// When the switch actor and gather actor have no input, need to link a control arrow from entrance actor.
|
// When the switch actor and gather actor have no input, need to link a control arrow from entrance actor.
|
||||||
std::vector<ControlActor *> need_check_control_actors;
|
std::vector<ControlActor *> need_check_control_actors;
|
||||||
|
@ -939,6 +913,62 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ControlNodeScheduler::LinkControlArrowForEntranceActor(ActorSet *const actor_set,
|
||||||
|
const GraphCompilerInfo &graph_compiler_info) {
|
||||||
|
MS_EXCEPTION_IF_NULL(actor_set);
|
||||||
|
auto control_actor_set = actor_set->control_actors_.get();
|
||||||
|
MS_EXCEPTION_IF_NULL(control_actor_set);
|
||||||
|
const auto &parser = graph_compiler_info.control_node_parser_;
|
||||||
|
MS_EXCEPTION_IF_NULL(parser);
|
||||||
|
|
||||||
|
// Since only one set of real parameters are allowed to be executed in funcgraph at the same time, when the funcgraph
|
||||||
|
// stops running, it is necessary to send the control arrow to the corresponding entrance actor at the exit of the
|
||||||
|
// graph to run the next set of real parameters. The corresponding nodes of the actors that need to send the control
|
||||||
|
// arrow have been parsed in the control node parser.
|
||||||
|
for (const auto &graph_to_nodes : parser->func_graph_to_first_control_nodes_) {
|
||||||
|
// Fetch the entrance actor.
|
||||||
|
const auto &func_graph = graph_to_nodes.first;
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
auto actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
|
||||||
|
auto entrance_actor = dynamic_cast<EntranceActor *>(FetchActor(actor_name));
|
||||||
|
MS_EXCEPTION_IF_NULL(entrance_actor);
|
||||||
|
|
||||||
|
const auto &nodes = graph_to_nodes.second;
|
||||||
|
for (const auto &node : nodes) {
|
||||||
|
// Fetch the source actor of control arrow.
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
|
||||||
|
actor_name = func_graph->ToString() + kExitActorNameSuffix;
|
||||||
|
} else {
|
||||||
|
actor_name = GetActorName(node);
|
||||||
|
}
|
||||||
|
auto from_actor = dynamic_cast<ControlActor *>(FetchActor(actor_name));
|
||||||
|
MS_EXCEPTION_IF_NULL(from_actor);
|
||||||
|
LinkLoopBodyControlArrow(from_actor, entrance_actor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// In the recursive scene, some kernel graph needs to be completed before the next set of data is sent by the
|
||||||
|
// entrance actor. At this time, it is necessary to connect a control arrow from the exit actor of the graph
|
||||||
|
// to the entrance actor.
|
||||||
|
for (const auto &func_graph_to_group_info : parser->func_graph_to_first_kernel_graphs_) {
|
||||||
|
const auto &func_graph = func_graph_to_group_info.first;
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
auto actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
|
||||||
|
auto actor = FetchActor(actor_name);
|
||||||
|
MS_EXCEPTION_IF_NULL(actor);
|
||||||
|
auto entrance_actor = dynamic_cast<EntranceActor *>(actor);
|
||||||
|
MS_EXCEPTION_IF_NULL(entrance_actor);
|
||||||
|
for (const auto &group_info : func_graph_to_group_info.second) {
|
||||||
|
MS_EXCEPTION_IF_NULL(group_info);
|
||||||
|
actor_name = group_info->group_name_ + kExitActorNameSuffix;
|
||||||
|
auto from_actor = FetchActor(actor_name);
|
||||||
|
MS_EXCEPTION_IF_NULL(from_actor);
|
||||||
|
LinkLoopBodyControlArrow(from_actor, entrance_actor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void ControlNodeScheduler::LinkControlArrowForLoopCountActor(const ActorSet *actor_set,
|
void ControlNodeScheduler::LinkControlArrowForLoopCountActor(const ActorSet *actor_set,
|
||||||
const GraphCompilerInfo &graph_compiler_info) {
|
const GraphCompilerInfo &graph_compiler_info) {
|
||||||
MS_EXCEPTION_IF_NULL(actor_set);
|
MS_EXCEPTION_IF_NULL(actor_set);
|
||||||
|
|
|
@ -58,6 +58,7 @@ class ControlNodeScheduler {
|
||||||
std::vector<StackActorPtr> *const stack_actors);
|
std::vector<StackActorPtr> *const stack_actors);
|
||||||
// Interface to link control actors.
|
// Interface to link control actors.
|
||||||
void LinkControlArrowForControlActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
|
void LinkControlArrowForControlActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
|
||||||
|
void LinkControlArrowForEntranceActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
|
||||||
void LinkBranchIDArrowForControlActor(ControlActorSet *const control_actor_set);
|
void LinkBranchIDArrowForControlActor(ControlActorSet *const control_actor_set);
|
||||||
// Link all arrows between control actors.
|
// Link all arrows between control actors.
|
||||||
void LinkArrowForControlActor(ControlActorSet *const control_actor_set, const GraphCompilerInfo &graph_compiler_info);
|
void LinkArrowForControlActor(ControlActorSet *const control_actor_set, const GraphCompilerInfo &graph_compiler_info);
|
||||||
|
|
Loading…
Reference in New Issue