From 7481779e957e598e612403700bbe1bf84b48a766 Mon Sep 17 00:00:00 2001 From: gaoyong10 Date: Tue, 29 Jun 2021 17:12:40 +0800 Subject: [PATCH] link auto monad for gather actor --- .../ccsrc/backend/session/kernel_graph.cc | 11 +++++- .../runtime/framework/control_node_parser.cc | 38 +++++++++++++++++++ .../runtime/framework/control_node_parser.h | 5 +++ .../runtime/framework/graph_scheduler.cc | 28 +++++++++----- .../ccsrc/runtime/framework/graph_scheduler.h | 4 +- 5 files changed, 73 insertions(+), 13 deletions(-) diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index 808aa960b11..630d7903482 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -1142,8 +1142,15 @@ FuncGraphPtr KernelGraph::GetFuncGraph() { if (front_backend_anf_map_.empty()) { return nullptr; } - const auto &front_node = front_backend_anf_map_.begin()->first; - return front_node->func_graph(); + + for (const auto &front_backend_anf : front_backend_anf_map_) { + const auto &front_node = front_backend_anf.first; + const auto &func_graph = front_node->func_graph(); + if (func_graph != nullptr) { + return func_graph; + } + } + return nullptr; } void KernelGraph::CacheGraphOutputToFrontNodeWithIndex(const AnfNodePtr &backend_graph_output, diff --git a/mindspore/ccsrc/runtime/framework/control_node_parser.cc b/mindspore/ccsrc/runtime/framework/control_node_parser.cc index f537e5a7c85..45a254a2c3a 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_parser.cc +++ b/mindspore/ccsrc/runtime/framework/control_node_parser.cc @@ -396,6 +396,20 @@ std::vector FetchOutputBySwitchNode(const AnfNodePtr &switch_node, s return outputs; } + +// Recursive interface, get the real kernel that UpdateState node depends on. +AnfNodePtr FetchSourceNodeByAutoMonad(const AnfNodePtr &node) { + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) { + const auto &cnode = node->cast(); + const auto &inputs = cnode->inputs(); + if (inputs.size() <= kUpdateStateRealInput) { + MS_LOG(EXCEPTION) << "Invalid updatestate node:" << AnfAlgo::GetNodeDebugString(node); + } + + return FetchSourceNodeByAutoMonad(inputs[kUpdateStateRealInput]); + } + return node; +} } // namespace // Return true if the node has Ref abstract. @@ -615,6 +629,8 @@ void ControlNodeParser::Parse(const std::vector &control_nodes, cons FetchCallInputKernelGraph(graphs, device_contexts); FetchBackendInputNode(graphs, device_contexts, real_to_formal_front_parameters, formal_to_real_front_parameters); + + FetchAutoMonadNode(control_nodes); } std::vector ControlNodeParser::GetBackendInputByParameter(const AnfNodePtr ¶meter) { @@ -1370,5 +1386,27 @@ void ControlNodeParser::FetchBackendInputNode(const std::vector formal_to_real_parameters_[parameter_pair.first].push_back({parameter_pair.second.first, 0}); } } + +void ControlNodeParser::FetchAutoMonadNode(const std::vector &control_nodes) { + for (const auto &control_node : control_nodes) { + const auto &cnode = control_node->cast(); + const auto &inputs = cnode->inputs(); + if (inputs.empty()) { + MS_LOG(EXCEPTION) << "Invalid control node:" << AnfAlgo::GetNodeDebugString(control_node); + } + + if (inputs[0]->isa() && IsValueNode(inputs[0])) { + for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) { + if (AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimUpdateState)) { + const auto &node = FetchSourceNodeByAutoMonad(inputs[i]); + const auto &iter = front_to_backend_kernels_.find(node); + if (iter != front_to_backend_kernels_.end()) { + kernel_to_call_nodes_[iter->second.first] = control_node; + } + } + } + } + } +} } // namespace runtime } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/framework/control_node_parser.h b/mindspore/ccsrc/runtime/framework/control_node_parser.h index 79f5b1f2f0e..d658230345a 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_parser.h +++ b/mindspore/ccsrc/runtime/framework/control_node_parser.h @@ -181,6 +181,8 @@ class ControlNodeParser { void FetchBackendOutputByFrontOutput(const AnfNodePtr &front_output, std::set *call_nodes, std::set *switch_nodes, std::set *results); + // Get the dependency between kernel and call node in auto monad. + void FetchAutoMonadNode(const std::vector &control_nodes); // The front to backend parameters is used to build and link the host data source actor in the control flow scenario. FrontToBackendNodeWithContext front_to_backend_parameters_; @@ -226,6 +228,9 @@ class ControlNodeParser { // Root funcgraph and its parameters. FuncGraphPtr root_func_graph_; std::vector root_graph_parameters_; + + // The dependency between kernel and call node in auto monad. + std::unordered_map kernel_to_call_nodes_; }; using ControlNodeParserPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc index c406aac41cf..e0f947cc514 100644 --- a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc +++ b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc @@ -1953,8 +1953,7 @@ void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compi LinkBranchArrowForGatherActor(graph_compiler_info, actor_set); - LinkControlArrowForGatherActor(&(actor_set->gather_actors_), &(actor_set->kernel_actors_), - actor_set->loop_count_actor_.get(), graph_compiler_info.graphs_, + LinkControlArrowForGatherActor(&(actor_set->kernel_actors_), graph_compiler_info.graphs_, graph_compiler_info.control_node_parser_); LinkControlArrowForSwitchActor(&(actor_set->switch_actors_), actor_set->loop_count_actor_.get(), @@ -2163,14 +2162,9 @@ void GraphScheduler::LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_ } } -void GraphScheduler::LinkControlArrowForGatherActor(std::vector *from_actors, - std::vector *kernel_actors, - LoopCountActor *to_actor, const std::vector &graphs, +void GraphScheduler::LinkControlArrowForGatherActor(std::vector *kernel_actors, + const std::vector &graphs, const ControlNodeParserPtr &parser) { - if (from_actors == nullptr || to_actor == nullptr) { - return; - } - // Link control arrow to kernel actor. for (size_t i = 0; i < graphs.size(); ++i) { const auto &kernel_graph = graphs[i]; @@ -2226,6 +2220,22 @@ void GraphScheduler::LinkControlArrowForGatherActor(std::vector } } } + + // Link input auto monad control arrow from kernel actor to gather actor. + const auto &monad_nodes = parser->kernel_to_call_nodes_; + for (const auto node_pair : monad_nodes) { + const auto &kernel_actor_name = node_pair.first->fullname_with_scope(); + const auto &gather_actor_name = node_pair.second->DebugString(); + auto kernel_op_actor = FetchActor(kernel_actor_name); + auto gather_op_actor = FetchActor(gather_actor_name); + if (kernel_op_actor == nullptr || gather_op_actor == nullptr) { + continue; + } + auto kernel_actor = dynamic_cast(kernel_op_actor); + auto gather_actor = dynamic_cast(gather_op_actor); + kernel_actor->output_control_arrows_.emplace_back(gather_actor->GetAID()); + gather_actor->input_controls_num_++; + } } void GraphScheduler::LinkControlArrowForSwitchActor(std::vector *switch_actors, diff --git a/mindspore/ccsrc/runtime/framework/graph_scheduler.h b/mindspore/ccsrc/runtime/framework/graph_scheduler.h index 7d02f63416f..8234d1a3c39 100644 --- a/mindspore/ccsrc/runtime/framework/graph_scheduler.h +++ b/mindspore/ccsrc/runtime/framework/graph_scheduler.h @@ -242,8 +242,8 @@ class GraphScheduler { const size_t to_index); void LinkDataArrowForSwitchActor(SwitchActor *from_actor, const size_t from_index, OpActor *to_actor, const size_t to_index, const size_t branch_index = SIZE_MAX); - void LinkControlArrowForGatherActor(std::vector *from_actors, - std::vector *kernel_actors, LoopCountActor *to_actor, + + void LinkControlArrowForGatherActor(std::vector *kernel_actors, const std::vector &graphs, const ControlNodeParserPtr &parser); void LinkControlArrowForSwitchActor(std::vector *switch_actors, LoopCountActor *to_actor,