optimize the multi-device execution performance and the communication node running order
parent 8008103050
commit 05be63887c
@@ -190,8 +190,8 @@ void PrepareDataForWeightNode(const AnfNodePtr &backend_node, const AnfNodePtr &
   if (host_tensor_address->DeviceType() == device_tensor->DeviceType()) {
     AnfAlgo::SetOutputAddr(host_tensor_address, 0, backend_node.get());
   } else {
-    MS_LOG(ERROR) << "The device type is not equal, host tensor type:" << host_tensor_address->DeviceType()
+    MS_LOG(INFO) << "The device type is not equal, host tensor type:" << host_tensor_address->DeviceType()
                   << ", device tensor type:" << device_tensor->DeviceType();
   }
 }

@@ -840,6 +840,7 @@ void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info)
 void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) {
   MS_EXCEPTION_IF_NULL(actor_set);
   std::vector<KernelActor *> auto_monad_actors;
+  std::vector<CNodePtr> communication_nodes;
   const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> auto_monad_prims = {
     prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad};

@@ -849,6 +850,9 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info)
     MS_EXCEPTION_IF_NULL(graph);
     auto execution_order = graph->execution_order();
     for (auto &kernel : execution_order) {
+      if (AnfAlgo::IsCommunicationOp(kernel)) {
+        communication_nodes.emplace_back(kernel);
+      }
       if (IsSkippedKernelActor(kernel) || (!IsKernelActor(kernel, graph_compiler_info.strategy_))) {
         continue;
       }
@@ -876,10 +880,11 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info)
     }
     // Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
     LinkControlArrowBySendRecvNodes(graph);
-    // Link the control arrows by the communication nodes to ensure communication nodes running order.
-    LinkControlArrowByCommunicationNode(graph);
   }

+  // Link the control arrows by the communication nodes to ensure communication nodes running order.
+  LinkControlArrowByCommunicationNode(communication_nodes);
+
   if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) {
     // Link the arrow by control node.
     LinkArrowByControlNode(graph_compiler_info, actor_set);
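Here the call is hoisted out of the per-graph loop: running order is now enforced across all graphs at once instead of within a single graph. The hunks that follow all rely on the same arrow-bookkeeping invariant, which always comes in pairs: an output control arrow on the upstream actor plus one input-count increment on the downstream actor. A hedged sketch of that invariant, with stand-in types for the actor framework:

// Sketch only: AID and KernelActor are stand-ins. The receiver fires
// only after input_controls_num_ control messages arrive, so every
// arrow must bump the counter exactly once.
#include <cstddef>
#include <string>
#include <vector>

struct AID {
  std::string name;
};

struct KernelActor {
  AID aid;
  std::vector<AID> output_control_arrows_;
  size_t input_controls_num_ = 0;
  const AID &GetAID() const { return aid; }
};

void LinkControlArrow(KernelActor *from, KernelActor *to) {
  from->output_control_arrows_.emplace_back(to->GetAID());
  to->input_controls_num_++;
}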
@@ -1670,8 +1675,10 @@ void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph)
   // inputs of to_allreduce_actor --> from_send_actor
   for (auto &input_aid : to_allreduce_actor->input_data_arrow_aids_) {
     auto input_actor = dynamic_cast<KernelActor *>(FetchActor(input_aid.Name()));
-    input_actor->output_control_arrows_.emplace_back(from_send_actor->GetAID());
-    from_send_actor->input_controls_num_++;
+    if (input_actor != nullptr) {
+      input_actor->output_control_arrows_.emplace_back(from_send_actor->GetAID());
+      from_send_actor->input_controls_num_++;
+    }
   }

   // from_send_actor --> from_recv_actor
@@ -1703,8 +1710,10 @@ void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph)
   // to_recv_actor --> outputs of from_allreduce_actor
   for (auto &output_data_arrow : from_allreduce_actor->output_data_arrows_) {
     auto output_actor = dynamic_cast<KernelActor *>(FetchActor(output_data_arrow->to_op_id_.Name()));
-    to_recv_actor->output_control_arrows_.emplace_back(output_actor->GetAID());
-    output_actor->input_controls_num_++;
+    if (output_actor != nullptr) {
+      to_recv_actor->output_control_arrows_.emplace_back(output_actor->GetAID());
+      output_actor->input_controls_num_++;
+    }
   }

   // In the scene of allreduce op and computing op parallel multi stream, the input memory of allreduce can be reused
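Both hunks above wrap the linking in a null check: FetchActor() can hand back an actor that is not a KernelActor, in which case dynamic_cast yields nullptr and the unguarded code would dereference a null pointer. A minimal sketch of the guarded-cast pattern, using placeholder types:

// Sketch only: placeholder actor types. A fetched actor may not be a
// KernelActor, so dynamic_cast can return nullptr and the link step must
// be skipped rather than dereferencing a null pointer.
#include <iostream>

struct OpActor {
  virtual ~OpActor() = default;
};

struct KernelActor : OpActor {
  int input_controls_num_ = 0;
};

void BumpIfKernelActor(OpActor *fetched) {
  auto *kernel_actor = dynamic_cast<KernelActor *>(fetched);
  if (kernel_actor != nullptr) {  // guard introduced by this commit's pattern
    kernel_actor->input_controls_num_++;
  }
}

int main() {
  OpActor plain_actor;      // e.g. a non-kernel actor behind FetchActor
  KernelActor kernel_actor;
  BumpIfKernelActor(&plain_actor);   // safely skipped
  BumpIfKernelActor(&kernel_actor);  // counter incremented
  std::cout << kernel_actor.input_controls_num_ << std::endl;  // prints 1
  return 0;
}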
@@ -1718,22 +1727,26 @@ void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph)
   }
 }

-void GraphScheduler::LinkControlArrowByCommunicationNode(const KernelGraphPtr &graph) {
-  std::vector<CNodePtr> communication_nodes;
-  auto execution_order = graph->execution_order();
-  for (auto &kernel : execution_order) {
-    if (AnfAlgo::IsCommunicationOp(kernel)) {
-      communication_nodes.emplace_back(kernel);
-    }
-  }
-
+void GraphScheduler::LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes) {
   for (size_t i = 1; i < communication_nodes.size(); ++i) {
     auto from_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i - 1]->fullname_with_scope()));
     auto to_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i]->fullname_with_scope()));
     MS_EXCEPTION_IF_NULL(from_actor);
     MS_EXCEPTION_IF_NULL(to_actor);
+    // Ensure communication node to execute orderly.
     from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
     to_actor->input_controls_num_++;
+
+    // Ensure the input actor of next communication actor is after the previous communication actor to optimize the
+    // execution performance in the multi device scenario.
+    // Using the multi stream to optimize the performance in the future.
+    for (auto &input_aid : to_actor->input_data_arrow_aids_) {
+      auto input_actor = dynamic_cast<KernelActor *>(FetchActor(input_aid.Name()));
+      if ((input_actor != nullptr) && (from_actor != input_actor)) {
+        from_actor->output_control_arrows_.emplace_back(input_actor->GetAID());
+        input_actor->input_controls_num_++;
+      }
+    }
   }
 }

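The rewritten LinkControlArrowByCommunicationNode() now receives the globally collected nodes and applies two ordering rules per adjacent pair: it chains the communication actors themselves so collectives launch in a fixed order, and it also draws arrows from the previous communication actor to the actors feeding the next one, so compute cannot race ahead of an unfinished collective (per the in-code comment, a single-stream stopgap until multi-stream support). A simplified model of this double linking, with stand-in types:

// Sketch only: Actor and the producers list are simplified stand-ins for
// the real actor/arrow structures in the graph scheduler.
#include <cstddef>
#include <string>
#include <vector>

struct Actor {
  std::string name;
  std::vector<Actor *> producers;             // actors feeding this one data
  std::vector<Actor *> output_control_arrows;
  size_t input_controls_num = 0;
};

void LinkControl(Actor *from, Actor *to) {
  from->output_control_arrows.push_back(to);
  to->input_controls_num++;
}

void OrderCommunicationActors(const std::vector<Actor *> &comm) {
  for (size_t i = 1; i < comm.size(); ++i) {
    Actor *prev = comm[i - 1];
    Actor *next = comm[i];
    // Rule 1: collectives launch in a fixed global order.
    LinkControl(prev, next);
    // Rule 2: the producers of the next collective wait for the previous
    // one, so compute does not overlap an unfinished collective.
    for (Actor *producer : next->producers) {
      if (producer != prev) {
        LinkControl(prev, producer);
      }
    }
  }
}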
@@ -221,7 +221,7 @@ class GraphScheduler {
   // Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
   void LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph);
   // Link the control arrows by the communication nodes in the kernel graph to ensure communication nodes running order.
-  void LinkControlArrowByCommunicationNode(const KernelGraphPtr &graph);
+  void LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes);
   void LinkDeviceTensorStoreForAutoMonadActor(const std::vector<KernelActor *> &auto_monad_actors);

   // 3. The processing of linking output result arrows.