!19414 fix the control ring bug

Merge pull request !19414 from limingqi107/bug_fix2
This commit is contained in:
i-robot 2021-07-05 17:55:44 +00:00 committed by Gitee
commit ede0139b06
2 changed files with 22 additions and 12 deletions

View File

@ -882,7 +882,7 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co
}
// Link the control arrows by the communication nodes to ensure communication nodes running order.
LinkControlArrowByCommunicationNode(communication_nodes);
LinkControlArrowByCommunicationNode(communication_nodes, graph_compiler_info);
if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) {
// Link the arrow by control node.
@ -1735,24 +1735,33 @@ void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph
}
}
void GraphScheduler::LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes) {
void GraphScheduler::LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes,
const GraphCompilerInfo &graph_compiler_info) {
const size_t kCommunicationNodesMinNum = 2;
if (communication_nodes.size() < kCommunicationNodesMinNum) {
return;
}
// Ensure communication node to execute orderly.
for (size_t i = 1; i < communication_nodes.size(); ++i) {
auto from_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i - 1]->fullname_with_scope()));
auto to_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i]->fullname_with_scope()));
MS_EXCEPTION_IF_NULL(from_actor);
MS_EXCEPTION_IF_NULL(to_actor);
// Ensure communication node to execute orderly.
from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
to_actor->input_controls_num_++;
}
// Ensure the input actor of next communication actor is after the previous communication actor to optimize the
// execution performance in the multi device scenario.
// Using the multi stream to optimize the performance in the future.
for (auto &input_aid : to_actor->input_data_arrow_aids_) {
auto input_actor = dynamic_cast<KernelActor *>(FetchActor(input_aid.Name()));
if ((input_actor != nullptr) && (from_actor != input_actor)) {
from_actor->output_control_arrows_.emplace_back(input_actor->GetAID());
input_actor->input_controls_num_++;
// Ensure all actors execute orderly to optimize the execution performance in the multi device scenario currently.
// Using the multi stream to optimize the performance in the future.
for (auto &graph : graph_compiler_info.graphs_) {
auto &execution_order = graph->execution_order();
for (size_t i = 1; i < execution_order.size(); ++i) {
auto from_actor = dynamic_cast<KernelActor *>(FetchActor(execution_order[i - 1]->fullname_with_scope()));
auto to_actor = dynamic_cast<KernelActor *>(FetchActor(execution_order[i]->fullname_with_scope()));
if ((from_actor != nullptr) && (to_actor != nullptr)) {
from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
to_actor->input_controls_num_++;
}
}
}

View File

@ -221,7 +221,8 @@ class GraphScheduler {
// Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
void LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph);
// Link the control arrows by the communication nodes in the kernel graph to ensure communication nodes running order.
void LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes);
void LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes,
const GraphCompilerInfo &graph_compiler_info);
void LinkDeviceTensorStoreForAutoMonadActor(const std::vector<KernelActor *> &auto_monad_actors);
// 3. The processing of linking output result arrows.