From 05be63887c55f4c339d3084d19983e34952f336c Mon Sep 17 00:00:00 2001
From: limingqi107
Date: Sun, 4 Jul 2021 09:50:28 +0800
Subject: [PATCH] optimize the multi device execution performance and
 communication node running order

---
 .../runtime/framework/graph_scheduler.cc      | 47 ++++++++++++-------
 .../ccsrc/runtime/framework/graph_scheduler.h |  2 +-
 2 files changed, 31 insertions(+), 18 deletions(-)

diff --git a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc
index e04ddfd9afc..6835a87e43c 100644
--- a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc
+++ b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc
@@ -190,8 +190,8 @@ void PrepareDataForWeightNode(const AnfNodePtr &backend_node, const AnfNodePtr &
   if (host_tensor_address->DeviceType() == device_tensor->DeviceType()) {
     AnfAlgo::SetOutputAddr(host_tensor_address, 0, backend_node.get());
   } else {
-    MS_LOG(ERROR) << "The device type is not equal, host tensor type:" << host_tensor_address->DeviceType()
-                  << ", device tensor type:" << device_tensor->DeviceType();
+    MS_LOG(INFO) << "The device type is not equal, host tensor type:" << host_tensor_address->DeviceType()
+                 << ", device tensor type:" << device_tensor->DeviceType();
   }
 }
 
@@ -840,6 +840,7 @@ void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_comp
 void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) {
   MS_EXCEPTION_IF_NULL(actor_set);
   std::vector<KernelActor *> auto_monad_actors;
+  std::vector<CNodePtr> communication_nodes;
   const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> auto_monad_prims = {
     prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad};
 
@@ -849,6 +850,9 @@
     MS_EXCEPTION_IF_NULL(graph);
     auto execution_order = graph->execution_order();
     for (auto &kernel : execution_order) {
+      if (AnfAlgo::IsCommunicationOp(kernel)) {
+        communication_nodes.emplace_back(kernel);
+      }
       if (IsSkippedKernelActor(kernel) || (!IsKernelActor(kernel, graph_compiler_info.strategy_))) {
         continue;
       }
@@ -876,10 +880,11 @@
     }
     // Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
     LinkControlArrowBySendRecvNodes(graph);
-    // Link the control arrows by the communication nodes to ensure communication nodes running order.
-    LinkControlArrowByCommunicationNode(graph);
   }
 
+  // Link the control arrows by the communication nodes to ensure the communication nodes execute in order.
+  LinkControlArrowByCommunicationNode(communication_nodes);
+
   if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) {
     // Link the arrow by control node.
     LinkArrowByControlNode(graph_compiler_info, actor_set);
@@ -1670,8 +1675,10 @@ void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph
   // inputs of to_allreduce_actor --> from_send_actor
   for (auto &input_aid : to_allreduce_actor->input_data_arrow_aids_) {
     auto input_actor = dynamic_cast<KernelActor *>(FetchActor(input_aid.Name()));
-    input_actor->output_control_arrows_.emplace_back(from_send_actor->GetAID());
-    from_send_actor->input_controls_num_++;
+    if (input_actor != nullptr) {
+      input_actor->output_control_arrows_.emplace_back(from_send_actor->GetAID());
+      from_send_actor->input_controls_num_++;
+    }
   }
 
   // from_send_actor --> from_recv_actor
@@ -1703,8 +1710,10 @@
   // to_recv_actor --> outputs of from_allreduce_actor
   for (auto &output_data_arrow : from_allreduce_actor->output_data_arrows_) {
     auto output_actor = dynamic_cast<KernelActor *>(FetchActor(output_data_arrow->to_op_id_.Name()));
-    to_recv_actor->output_control_arrows_.emplace_back(output_actor->GetAID());
-    output_actor->input_controls_num_++;
+    if (output_actor != nullptr) {
+      to_recv_actor->output_control_arrows_.emplace_back(output_actor->GetAID());
+      output_actor->input_controls_num_++;
+    }
   }
 
   // In the scene of allreduce op and computing op parallel multi stream, the input memory of allreduce can be reused
@@ -1718,22 +1727,26 @@ void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph
   }
 }
 
-void GraphScheduler::LinkControlArrowByCommunicationNode(const KernelGraphPtr &graph) {
-  std::vector<CNodePtr> communication_nodes;
-  auto execution_order = graph->execution_order();
-  for (auto &kernel : execution_order) {
-    if (AnfAlgo::IsCommunicationOp(kernel)) {
-      communication_nodes.emplace_back(kernel);
-    }
-  }
-
+void GraphScheduler::LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes) {
   for (size_t i = 1; i < communication_nodes.size(); ++i) {
     auto from_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i - 1]->fullname_with_scope()));
     auto to_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i]->fullname_with_scope()));
     MS_EXCEPTION_IF_NULL(from_actor);
     MS_EXCEPTION_IF_NULL(to_actor);
+    // Ensure the communication nodes execute in order.
     from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
     to_actor->input_controls_num_++;
+
+    // Ensure the input actors of the next communication actor run after the previous communication actor, which
+    // optimizes the execution performance in the multi-device scenario.
+    // Multi-stream execution may be used for this optimization in the future.
+    for (auto &input_aid : to_actor->input_data_arrow_aids_) {
+      auto input_actor = dynamic_cast<KernelActor *>(FetchActor(input_aid.Name()));
+      if ((input_actor != nullptr) && (from_actor != input_actor)) {
+        from_actor->output_control_arrows_.emplace_back(input_actor->GetAID());
+        input_actor->input_controls_num_++;
+      }
+    }
   }
 }
 
diff --git a/mindspore/ccsrc/runtime/framework/graph_scheduler.h b/mindspore/ccsrc/runtime/framework/graph_scheduler.h
index 24b238cc784..b8927592c57 100644
--- a/mindspore/ccsrc/runtime/framework/graph_scheduler.h
+++ b/mindspore/ccsrc/runtime/framework/graph_scheduler.h
@@ -221,7 +221,7 @@ class GraphScheduler {
   // Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
   void LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph);
   // Link the control arrows by the communication nodes in the kernel graph to ensure communication nodes running order.
-  void LinkControlArrowByCommunicationNode(const KernelGraphPtr &graph);
+  void LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes);
   void LinkDeviceTensorStoreForAutoMonadActor(const std::vector<KernelActor *> &auto_monad_actors);
 
   // 3. The processing of linking output result arrows.
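
Note for reviewers: the core of this change is the second linking rule in LinkControlArrowByCommunicationNode, which not only serializes consecutive communication kernels but also pulls the producers of the next communication kernel's inputs behind the previous communication kernel. Below is a minimal, self-contained C++ sketch of that rule; Actor, LinkControl, and the demo node names are illustrative stand-ins, not MindSpore's real actor classes.

// Minimal sketch of the communication-node linking rule (stand-in types, not
// MindSpore's real classes): chain consecutive communication actors, and make
// the producers of the next communication actor's inputs wait on the previous
// communication actor as well.
#include <cstddef>
#include <initializer_list>
#include <iostream>
#include <string>
#include <vector>

struct Actor {
  std::string name;
  std::vector<Actor *> output_control_arrows;  // actors triggered after this one finishes
  std::vector<Actor *> input_data_arrows;      // producers of this actor's input data
  size_t input_controls_num = 0;               // number of control dependencies to wait for
};

// Add a control arrow: `to` cannot run until `from` has finished.
void LinkControl(Actor *from, Actor *to) {
  from->output_control_arrows.push_back(to);
  ++to->input_controls_num;
}

// Mirrors LinkControlArrowByCommunicationNode: the communication actors arrive
// in execution order; each is linked to the next, and also to the next one's
// input producers, so ordinary compute cannot start early and stall the collective.
void LinkControlArrowByCommunicationNode(const std::vector<Actor *> &comm_actors) {
  for (size_t i = 1; i < comm_actors.size(); ++i) {
    Actor *from = comm_actors[i - 1];
    Actor *to = comm_actors[i];
    LinkControl(from, to);  // comm[i-1] --> comm[i]
    for (Actor *input : to->input_data_arrows) {
      if ((input != nullptr) && (input != from)) {
        LinkControl(from, input);  // comm[i-1] --> inputs of comm[i]
      }
    }
  }
}

int main() {
  Actor allreduce1{"AllReduce1"};
  Actor matmul{"MatMul"};
  Actor allreduce2{"AllReduce2"};
  allreduce2.input_data_arrows = {&matmul};  // MatMul feeds the second collective

  LinkControlArrowByCommunicationNode({&allreduce1, &allreduce2});

  for (Actor *actor : {&allreduce1, &allreduce2, &matmul}) {
    for (Actor *target : actor->output_control_arrows) {
      std::cout << actor->name << " --> " << target->name << "\n";
    }
  }
  return 0;
}

Running the sketch prints AllReduce1 --> AllReduce2 and AllReduce1 --> MatMul: MatMul can no longer start before AllReduce1 finishes, so it cannot delay the second collective on the other devices.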