diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index a46614c4257..3c5196915ee 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -542,7 +542,8 @@ void ExceptionIfHasCommunicationOp(const std::vector &all_nodes) { } } -void StepRedistribution(const CNodePtr &cnode, const TensorRedistribution &tensor_redistribution) { +void StepRedistribution(const CNodePtr &cnode, const TensorRedistribution &tensor_redistribution, + const NodeUsersMap &node_users_map) { MS_EXCEPTION_IF_NULL(cnode->func_graph()); FuncGraphManagerPtr manager = cnode->func_graph()->manager(); MS_EXCEPTION_IF_NULL(manager); @@ -559,9 +560,8 @@ void StepRedistribution(const CNodePtr &cnode, const TensorRedistribution &tenso } // Find Redistribution next_nodes - auto node_users_map = manager->node_users(); std::vector, int>> next_nodes; - RedistributionNextNode(cnode, manager, &node_users_map, -1, &next_nodes); + RedistributionNextNode(cnode, manager, node_users_map, -1, &next_nodes); // Insert Redistribution nodes between pre_nodes and next_nodes for (auto &pre_node : pre_nodes) { @@ -2476,12 +2476,13 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vectornode_users(); for (auto &node : all_nodes) { MS_EXCEPTION_IF_NULL(node); if (node->isa()) { auto cnode = node->cast(); if (IsValueNode(cnode->input(0))) { - StepRedistribution(cnode, tensor_redistribution); + StepRedistribution(cnode, tensor_redistribution, node_users_map); continue; } // the make_tuple is parallel care node, but it may have not operator info @@ -2502,7 +2503,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes); -void StepRedistribution(const CNodePtr &cnode, const TensorRedistribution &tensor_redistribution); +void StepRedistribution(const CNodePtr &cnode, const TensorRedistribution &tensor_redistribution, + const NodeUsersMap &node_users_map); void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node); diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel_utils.cc b/mindspore/ccsrc/frontend/parallel/step_parallel_utils.cc index 5c224643354..35afd1b237b 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel_utils.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel_utils.cc @@ -247,12 +247,14 @@ int64_t GetTupleGetItemIndex(const CNodePtr &cnode) { return tuple_index_value->cast()->value(); } -void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &manager, NodeUsersMap *node_users_map, - int64_t get_item_index, +void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &manager, + const NodeUsersMap &node_users_map, int64_t get_item_index, std::vector, int>> *next_nodes) { MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(node_users_map); - auto node_set = (*node_users_map)[node]; + if (node_users_map.count(node) == 0) { + return; + } + auto node_set = node_users_map.at(node); for (auto &node_pair : node_set) { auto use_cnode = node_pair.first->cast(); MS_EXCEPTION_IF_NULL(use_cnode); @@ -275,6 +277,8 @@ void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &m } if (IsParallelCareNode(use_cnode) && use_cnode->has_user_data()) { next_nodes->push_back(std::make_pair(node_pair, get_item_index)); + } else if (use_cnode->input(0)->isa()) { + continue; } else { // search recursively RedistributionNextNode(use_cnode, manager, node_users_map, get_item_index, next_nodes); diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel_utils.h b/mindspore/ccsrc/frontend/parallel/step_parallel_utils.h index b2a7677d238..8acf5489dce 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel_utils.h +++ b/mindspore/ccsrc/frontend/parallel/step_parallel_utils.h @@ -57,8 +57,8 @@ int64_t GetTupleGetItemIndex(const CNodePtr &cnode); AnfNodePtr GetRealKernelNode(const AnfNodePtr &node, int64_t get_item_index, CNodePtr *call_node = nullptr); void RedistributionPreNode(const CNodePtr &cnode, const FuncGraphManagerPtr &manager, std::vector *pre_nodes); -void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &manager, NodeUsersMap *node_users_map, - int64_t get_item_index, +void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &manager, + const NodeUsersMap &node_users_map, int64_t get_item_index, std::vector, int>> *next_nodes); // for specific scenarios