optimize the multi device execution performance and communication node running order

This commit is contained in:
limingqi107 2021-07-04 09:50:28 +08:00
parent 8008103050
commit 05be63887c
2 changed files with 31 additions and 18 deletions

View File

@ -190,7 +190,7 @@ void PrepareDataForWeightNode(const AnfNodePtr &backend_node, const AnfNodePtr &
if (host_tensor_address->DeviceType() == device_tensor->DeviceType()) {
AnfAlgo::SetOutputAddr(host_tensor_address, 0, backend_node.get());
} else {
MS_LOG(ERROR) << "The device type is not equal, host tensor type:" << host_tensor_address->DeviceType()
MS_LOG(INFO) << "The device type is not equal, host tensor type:" << host_tensor_address->DeviceType()
<< ", device tensor type:" << device_tensor->DeviceType();
}
}
@ -840,6 +840,7 @@ void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_comp
void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) {
MS_EXCEPTION_IF_NULL(actor_set);
std::vector<KernelActor *> auto_monad_actors;
std::vector<CNodePtr> communication_nodes;
const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> auto_monad_prims = {
prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad};
@ -849,6 +850,9 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co
MS_EXCEPTION_IF_NULL(graph);
auto execution_order = graph->execution_order();
for (auto &kernel : execution_order) {
if (AnfAlgo::IsCommunicationOp(kernel)) {
communication_nodes.emplace_back(kernel);
}
if (IsSkippedKernelActor(kernel) || (!IsKernelActor(kernel, graph_compiler_info.strategy_))) {
continue;
}
@ -876,10 +880,11 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co
}
// Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
LinkControlArrowBySendRecvNodes(graph);
// Link the control arrows by the communication nodes to ensure communication nodes running order.
LinkControlArrowByCommunicationNode(graph);
}
// Link the control arrows by the communication nodes to ensure communication nodes running order.
LinkControlArrowByCommunicationNode(communication_nodes);
if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) {
// Link the arrow by control node.
LinkArrowByControlNode(graph_compiler_info, actor_set);
@ -1670,9 +1675,11 @@ void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph
// inputs of to_allreduce_actor --> from_send_actor
for (auto &input_aid : to_allreduce_actor->input_data_arrow_aids_) {
auto input_actor = dynamic_cast<KernelActor *>(FetchActor(input_aid.Name()));
if (input_actor != nullptr) {
input_actor->output_control_arrows_.emplace_back(from_send_actor->GetAID());
from_send_actor->input_controls_num_++;
}
}
// from_send_actor --> from_recv_actor
from_send_actor->output_control_arrows_.emplace_back(from_recv_actor->GetAID());
@ -1703,9 +1710,11 @@ void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph
// to_recv_actor --> outputs of from_allreduce_actor
for (auto &output_data_arrow : from_allreduce_actor->output_data_arrows_) {
auto output_actor = dynamic_cast<KernelActor *>(FetchActor(output_data_arrow->to_op_id_.Name()));
if (output_actor != nullptr) {
to_recv_actor->output_control_arrows_.emplace_back(output_actor->GetAID());
output_actor->input_controls_num_++;
}
}
// In the scene of allreduce op and computing op parallel multi stream, the input memory of allreduce can be reused
// only when the recv node runs finished, which is expressed by the reference count increased.
@ -1718,22 +1727,26 @@ void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph
}
}
void GraphScheduler::LinkControlArrowByCommunicationNode(const KernelGraphPtr &graph) {
std::vector<CNodePtr> communication_nodes;
auto execution_order = graph->execution_order();
for (auto &kernel : execution_order) {
if (AnfAlgo::IsCommunicationOp(kernel)) {
communication_nodes.emplace_back(kernel);
}
}
void GraphScheduler::LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes) {
for (size_t i = 1; i < communication_nodes.size(); ++i) {
auto from_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i - 1]->fullname_with_scope()));
auto to_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i]->fullname_with_scope()));
MS_EXCEPTION_IF_NULL(from_actor);
MS_EXCEPTION_IF_NULL(to_actor);
// Ensure communication node to execute orderly.
from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
to_actor->input_controls_num_++;
// Ensure the input actor of next communication actor is after the previous communication actor to optimize the
// execution performance in the multi device scenario.
// Using the multi stream to optimize the performance in the future.
for (auto &input_aid : to_actor->input_data_arrow_aids_) {
auto input_actor = dynamic_cast<KernelActor *>(FetchActor(input_aid.Name()));
if ((input_actor != nullptr) && (from_actor != input_actor)) {
from_actor->output_control_arrows_.emplace_back(input_actor->GetAID());
input_actor->input_controls_num_++;
}
}
}
}

View File

@ -221,7 +221,7 @@ class GraphScheduler {
// Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
void LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph);
// Link the control arrows by the communication nodes in the kernel graph to ensure communication nodes running order.
void LinkControlArrowByCommunicationNode(const KernelGraphPtr &graph);
void LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes);
void LinkDeviceTensorStoreForAutoMonadActor(const std::vector<KernelActor *> &auto_monad_actors);
// 3. The processing of linking output result arrows.