optimize the multi device execution performance and communication node running order
This commit is contained in:
parent
8008103050
commit
05be63887c
|
@ -190,7 +190,7 @@ void PrepareDataForWeightNode(const AnfNodePtr &backend_node, const AnfNodePtr &
|
|||
if (host_tensor_address->DeviceType() == device_tensor->DeviceType()) {
|
||||
AnfAlgo::SetOutputAddr(host_tensor_address, 0, backend_node.get());
|
||||
} else {
|
||||
MS_LOG(ERROR) << "The device type is not equal, host tensor type:" << host_tensor_address->DeviceType()
|
||||
MS_LOG(INFO) << "The device type is not equal, host tensor type:" << host_tensor_address->DeviceType()
|
||||
<< ", device tensor type:" << device_tensor->DeviceType();
|
||||
}
|
||||
}
|
||||
|
@ -840,6 +840,7 @@ void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_comp
|
|||
void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) {
|
||||
MS_EXCEPTION_IF_NULL(actor_set);
|
||||
std::vector<KernelActor *> auto_monad_actors;
|
||||
std::vector<CNodePtr> communication_nodes;
|
||||
const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> auto_monad_prims = {
|
||||
prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad};
|
||||
|
||||
|
@ -849,6 +850,9 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto execution_order = graph->execution_order();
|
||||
for (auto &kernel : execution_order) {
|
||||
if (AnfAlgo::IsCommunicationOp(kernel)) {
|
||||
communication_nodes.emplace_back(kernel);
|
||||
}
|
||||
if (IsSkippedKernelActor(kernel) || (!IsKernelActor(kernel, graph_compiler_info.strategy_))) {
|
||||
continue;
|
||||
}
|
||||
|
@ -876,10 +880,11 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co
|
|||
}
|
||||
// Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
|
||||
LinkControlArrowBySendRecvNodes(graph);
|
||||
// Link the control arrows by the communication nodes to ensure communication nodes running order.
|
||||
LinkControlArrowByCommunicationNode(graph);
|
||||
}
|
||||
|
||||
// Link the control arrows by the communication nodes to ensure communication nodes running order.
|
||||
LinkControlArrowByCommunicationNode(communication_nodes);
|
||||
|
||||
if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) {
|
||||
// Link the arrow by control node.
|
||||
LinkArrowByControlNode(graph_compiler_info, actor_set);
|
||||
|
@ -1670,9 +1675,11 @@ void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph
|
|||
// inputs of to_allreduce_actor --> from_send_actor
|
||||
for (auto &input_aid : to_allreduce_actor->input_data_arrow_aids_) {
|
||||
auto input_actor = dynamic_cast<KernelActor *>(FetchActor(input_aid.Name()));
|
||||
if (input_actor != nullptr) {
|
||||
input_actor->output_control_arrows_.emplace_back(from_send_actor->GetAID());
|
||||
from_send_actor->input_controls_num_++;
|
||||
}
|
||||
}
|
||||
|
||||
// from_send_actor --> from_recv_actor
|
||||
from_send_actor->output_control_arrows_.emplace_back(from_recv_actor->GetAID());
|
||||
|
@ -1703,9 +1710,11 @@ void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph
|
|||
// to_recv_actor --> outputs of from_allreduce_actor
|
||||
for (auto &output_data_arrow : from_allreduce_actor->output_data_arrows_) {
|
||||
auto output_actor = dynamic_cast<KernelActor *>(FetchActor(output_data_arrow->to_op_id_.Name()));
|
||||
if (output_actor != nullptr) {
|
||||
to_recv_actor->output_control_arrows_.emplace_back(output_actor->GetAID());
|
||||
output_actor->input_controls_num_++;
|
||||
}
|
||||
}
|
||||
|
||||
// In the scene of allreduce op and computing op parallel multi stream, the input memory of allreduce can be reused
|
||||
// only when the recv node runs finished, which is expressed by the reference count increased.
|
||||
|
@ -1718,22 +1727,26 @@ void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph
|
|||
}
|
||||
}
|
||||
|
||||
void GraphScheduler::LinkControlArrowByCommunicationNode(const KernelGraphPtr &graph) {
|
||||
std::vector<CNodePtr> communication_nodes;
|
||||
auto execution_order = graph->execution_order();
|
||||
for (auto &kernel : execution_order) {
|
||||
if (AnfAlgo::IsCommunicationOp(kernel)) {
|
||||
communication_nodes.emplace_back(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
void GraphScheduler::LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes) {
|
||||
for (size_t i = 1; i < communication_nodes.size(); ++i) {
|
||||
auto from_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i - 1]->fullname_with_scope()));
|
||||
auto to_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i]->fullname_with_scope()));
|
||||
MS_EXCEPTION_IF_NULL(from_actor);
|
||||
MS_EXCEPTION_IF_NULL(to_actor);
|
||||
// Ensure communication node to execute orderly.
|
||||
from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
|
||||
to_actor->input_controls_num_++;
|
||||
|
||||
// Ensure the input actor of next communication actor is after the previous communication actor to optimize the
|
||||
// execution performance in the multi device scenario.
|
||||
// Using the multi stream to optimize the performance in the future.
|
||||
for (auto &input_aid : to_actor->input_data_arrow_aids_) {
|
||||
auto input_actor = dynamic_cast<KernelActor *>(FetchActor(input_aid.Name()));
|
||||
if ((input_actor != nullptr) && (from_actor != input_actor)) {
|
||||
from_actor->output_control_arrows_.emplace_back(input_actor->GetAID());
|
||||
input_actor->input_controls_num_++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -221,7 +221,7 @@ class GraphScheduler {
|
|||
// Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
|
||||
void LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph);
|
||||
// Link the control arrows by the communication nodes in the kernel graph to ensure communication nodes running order.
|
||||
void LinkControlArrowByCommunicationNode(const KernelGraphPtr &graph);
|
||||
void LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes);
|
||||
void LinkDeviceTensorStoreForAutoMonadActor(const std::vector<KernelActor *> &auto_monad_actors);
|
||||
|
||||
// 3. The processing of linking output result arrows.
|
||||
|
|
Loading…
Reference in New Issue