Optimize the multi-device execution performance and the communication node running order

This commit is contained in:
limingqi107 2021-07-04 09:50:28 +08:00
parent 8008103050
commit 05be63887c
2 changed files with 31 additions and 18 deletions

View File

@ -190,8 +190,8 @@ void PrepareDataForWeightNode(const AnfNodePtr &backend_node, const AnfNodePtr &
if (host_tensor_address->DeviceType() == device_tensor->DeviceType()) { if (host_tensor_address->DeviceType() == device_tensor->DeviceType()) {
AnfAlgo::SetOutputAddr(host_tensor_address, 0, backend_node.get()); AnfAlgo::SetOutputAddr(host_tensor_address, 0, backend_node.get());
} else { } else {
MS_LOG(ERROR) << "The device type is not equal, host tensor type:" << host_tensor_address->DeviceType() MS_LOG(INFO) << "The device type is not equal, host tensor type:" << host_tensor_address->DeviceType()
<< ", device tensor type:" << device_tensor->DeviceType(); << ", device tensor type:" << device_tensor->DeviceType();
} }
} }
@ -840,6 +840,7 @@ void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_comp
void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) { void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) {
MS_EXCEPTION_IF_NULL(actor_set); MS_EXCEPTION_IF_NULL(actor_set);
std::vector<KernelActor *> auto_monad_actors; std::vector<KernelActor *> auto_monad_actors;
std::vector<CNodePtr> communication_nodes;
const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> auto_monad_prims = { const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> auto_monad_prims = {
prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad}; prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad};
@ -849,6 +850,9 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
auto execution_order = graph->execution_order(); auto execution_order = graph->execution_order();
for (auto &kernel : execution_order) { for (auto &kernel : execution_order) {
if (AnfAlgo::IsCommunicationOp(kernel)) {
communication_nodes.emplace_back(kernel);
}
if (IsSkippedKernelActor(kernel) || (!IsKernelActor(kernel, graph_compiler_info.strategy_))) { if (IsSkippedKernelActor(kernel) || (!IsKernelActor(kernel, graph_compiler_info.strategy_))) {
continue; continue;
} }
@ -876,10 +880,11 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co
} }
// Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph. // Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
LinkControlArrowBySendRecvNodes(graph); LinkControlArrowBySendRecvNodes(graph);
// Link the control arrows by the communication nodes to ensure communication nodes running order.
LinkControlArrowByCommunicationNode(graph);
} }
// Link the control arrows by the communication nodes to ensure communication nodes running order.
LinkControlArrowByCommunicationNode(communication_nodes);
if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) { if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) {
// Link the arrow by control node. // Link the arrow by control node.
LinkArrowByControlNode(graph_compiler_info, actor_set); LinkArrowByControlNode(graph_compiler_info, actor_set);
@ -1670,8 +1675,10 @@ void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph
// inputs of to_allreduce_actor --> from_send_actor // inputs of to_allreduce_actor --> from_send_actor
for (auto &input_aid : to_allreduce_actor->input_data_arrow_aids_) { for (auto &input_aid : to_allreduce_actor->input_data_arrow_aids_) {
auto input_actor = dynamic_cast<KernelActor *>(FetchActor(input_aid.Name())); auto input_actor = dynamic_cast<KernelActor *>(FetchActor(input_aid.Name()));
input_actor->output_control_arrows_.emplace_back(from_send_actor->GetAID()); if (input_actor != nullptr) {
from_send_actor->input_controls_num_++; input_actor->output_control_arrows_.emplace_back(from_send_actor->GetAID());
from_send_actor->input_controls_num_++;
}
} }
// from_send_actor --> from_recv_actor // from_send_actor --> from_recv_actor
@ -1703,8 +1710,10 @@ void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph
// to_recv_actor --> outputs of from_allreduce_actor // to_recv_actor --> outputs of from_allreduce_actor
for (auto &output_data_arrow : from_allreduce_actor->output_data_arrows_) { for (auto &output_data_arrow : from_allreduce_actor->output_data_arrows_) {
auto output_actor = dynamic_cast<KernelActor *>(FetchActor(output_data_arrow->to_op_id_.Name())); auto output_actor = dynamic_cast<KernelActor *>(FetchActor(output_data_arrow->to_op_id_.Name()));
to_recv_actor->output_control_arrows_.emplace_back(output_actor->GetAID()); if (output_actor != nullptr) {
output_actor->input_controls_num_++; to_recv_actor->output_control_arrows_.emplace_back(output_actor->GetAID());
output_actor->input_controls_num_++;
}
} }
// In the scene of allreduce op and computing op parallel multi stream, the input memory of allreduce can be reused // In the scene of allreduce op and computing op parallel multi stream, the input memory of allreduce can be reused
@ -1718,22 +1727,26 @@ void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph
} }
} }
void GraphScheduler::LinkControlArrowByCommunicationNode(const KernelGraphPtr &graph) { void GraphScheduler::LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes) {
std::vector<CNodePtr> communication_nodes;
auto execution_order = graph->execution_order();
for (auto &kernel : execution_order) {
if (AnfAlgo::IsCommunicationOp(kernel)) {
communication_nodes.emplace_back(kernel);
}
}
for (size_t i = 1; i < communication_nodes.size(); ++i) { for (size_t i = 1; i < communication_nodes.size(); ++i) {
auto from_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i - 1]->fullname_with_scope())); auto from_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i - 1]->fullname_with_scope()));
auto to_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i]->fullname_with_scope())); auto to_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i]->fullname_with_scope()));
MS_EXCEPTION_IF_NULL(from_actor); MS_EXCEPTION_IF_NULL(from_actor);
MS_EXCEPTION_IF_NULL(to_actor); MS_EXCEPTION_IF_NULL(to_actor);
// Ensure communication node to execute orderly.
from_actor->output_control_arrows_.emplace_back(to_actor->GetAID()); from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
to_actor->input_controls_num_++; to_actor->input_controls_num_++;
// Ensure the input actor of next communication actor is after the previous communication actor to optimize the
// execution performance in the multi device scenario.
// Using the multi stream to optimize the performance in the future.
for (auto &input_aid : to_actor->input_data_arrow_aids_) {
auto input_actor = dynamic_cast<KernelActor *>(FetchActor(input_aid.Name()));
if ((input_actor != nullptr) && (from_actor != input_actor)) {
from_actor->output_control_arrows_.emplace_back(input_actor->GetAID());
input_actor->input_controls_num_++;
}
}
} }
} }

View File

@ -221,7 +221,7 @@ class GraphScheduler {
// Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph. // Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
void LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph); void LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph);
// Link the control arrows by the communication nodes in the kernel graph to ensure communication nodes running order. // Link the control arrows by the communication nodes in the kernel graph to ensure communication nodes running order.
void LinkControlArrowByCommunicationNode(const KernelGraphPtr &graph); void LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes);
void LinkDeviceTensorStoreForAutoMonadActor(const std::vector<KernelActor *> &auto_monad_actors); void LinkDeviceTensorStoreForAutoMonadActor(const std::vector<KernelActor *> &auto_monad_actors);
// 3. The processing of linking output result arrows. // 3. The processing of linking output result arrows.