diff --git a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc index e88462c673d..0849c4fb22f 100644 --- a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc +++ b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc @@ -882,7 +882,7 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co } // Link the control arrows by the communication nodes to ensure communication nodes running order. - LinkControlArrowByCommunicationNode(communication_nodes); + LinkControlArrowByCommunicationNode(communication_nodes, graph_compiler_info); if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) { // Link the arrow by control node. @@ -1735,24 +1735,33 @@ void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph } } -void GraphScheduler::LinkControlArrowByCommunicationNode(const std::vector &communication_nodes) { +void GraphScheduler::LinkControlArrowByCommunicationNode(const std::vector &communication_nodes, + const GraphCompilerInfo &graph_compiler_info) { + const size_t kCommunicationNodesMinNum = 2; + if (communication_nodes.size() < kCommunicationNodesMinNum) { + return; + } + + // Ensure communication node to execute orderly. for (size_t i = 1; i < communication_nodes.size(); ++i) { auto from_actor = dynamic_cast(FetchActor(communication_nodes[i - 1]->fullname_with_scope())); auto to_actor = dynamic_cast(FetchActor(communication_nodes[i]->fullname_with_scope())); MS_EXCEPTION_IF_NULL(from_actor); MS_EXCEPTION_IF_NULL(to_actor); - // Ensure communication node to execute orderly. from_actor->output_control_arrows_.emplace_back(to_actor->GetAID()); to_actor->input_controls_num_++; + } - // Ensure the input actor of next communication actor is after the previous communication actor to optimize the - // execution performance in the multi device scenario. - // Using the multi stream to optimize the performance in the future. - for (auto &input_aid : to_actor->input_data_arrow_aids_) { - auto input_actor = dynamic_cast(FetchActor(input_aid.Name())); - if ((input_actor != nullptr) && (from_actor != input_actor)) { - from_actor->output_control_arrows_.emplace_back(input_actor->GetAID()); - input_actor->input_controls_num_++; + // Ensure all actors execute orderly to optimize the execution performance in the multi device scenario currently. + // Using the multi stream to optimize the performance in the future. + for (auto &graph : graph_compiler_info.graphs_) { + auto &execution_order = graph->execution_order(); + for (size_t i = 1; i < execution_order.size(); ++i) { + auto from_actor = dynamic_cast(FetchActor(execution_order[i - 1]->fullname_with_scope())); + auto to_actor = dynamic_cast(FetchActor(execution_order[i]->fullname_with_scope())); + if ((from_actor != nullptr) && (to_actor != nullptr)) { + from_actor->output_control_arrows_.emplace_back(to_actor->GetAID()); + to_actor->input_controls_num_++; } } } diff --git a/mindspore/ccsrc/runtime/framework/graph_scheduler.h b/mindspore/ccsrc/runtime/framework/graph_scheduler.h index b8927592c57..f6904939b7e 100644 --- a/mindspore/ccsrc/runtime/framework/graph_scheduler.h +++ b/mindspore/ccsrc/runtime/framework/graph_scheduler.h @@ -221,7 +221,8 @@ class GraphScheduler { // Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph. void LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph); // Link the control arrows by the communication nodes in the kernel graph to ensure communication nodes running order. - void LinkControlArrowByCommunicationNode(const std::vector &communication_nodes); + void LinkControlArrowByCommunicationNode(const std::vector &communication_nodes, + const GraphCompilerInfo &graph_compiler_info); void LinkDeviceTensorStoreForAutoMonadActor(const std::vector &auto_monad_actors); // 3. The processing of linking output result arrows.