diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index c31695d62cd..6c013cde1b2 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -135,23 +135,6 @@ std::string GetNodeGroup(const AnfNodePtr &node) { } return ""; } - -bool NeedOptimizeCommOp(const AnfNodePtr &node, std::map *optimized_comm_group) { - MS_EXCEPTION_IF_NULL(optimized_comm_group); - auto node_group = GetNodeGroup(node); - if (node_group.find(kSyncBnGroup) != string::npos) { - return false; - } - auto node_name = AnfAlgo::GetCNodeName(node); - auto iter = optimized_comm_group->find(node_name); - if (iter == optimized_comm_group->end()) { - (*optimized_comm_group)[node_name] = node_group; - return true; - } else if (iter->second == node_group) { - return true; - } - return false; -} } // namespace AnfNodePtr KernelGraph::MakeValueNode(const AnfNodePtr &node) const { @@ -188,7 +171,6 @@ void KernelGraph::EnqueueActiveNodes(const AnfNodePtr &node, std::queue active_nodes; for (const auto &output_edge : it->second) { @@ -209,7 +191,9 @@ void KernelGraph::EnqueueActiveNodes(const AnfNodePtr &node, std::queuefind(next_node) == visited_nodes->end()) { (void)visited_nodes->insert(next_node); bool is_comm_node = AnfAlgo::IsCommunicationOp(next_node); - if ((is_comm_node && comm_first) || (!is_comm_node && !comm_first)) { + if (AnfAlgo::CheckPrimitiveType(next_node, prim::kPrimLoad)) { + EnqueueActiveNodes(next_node, visit_queue, visited_nodes); + } else if ((is_comm_node && comm_first) || (!is_comm_node && !comm_first)) { MS_LOG(DEBUG) << "Visit node:" << next_node->DebugString(); visit_queue->push(next_node); } else { @@ -217,10 +201,7 @@ void KernelGraph::EnqueueActiveNodes(const AnfNodePtr &node, std::queueDebugString(); visit_queue->push(active_node); } } @@ -233,7 +214,7 @@ void KernelGraph::SetExecOrderByDefault() { std::queue zero_input_nodes; std::stack delay_comm_stack; std::queue communication_descendants; - std::map optimized_comm_group; + std::string optimized_comm_group; while (!seed_nodes.empty() || !delay_comm_stack.empty()) { // seed nodes first, then delay comm nodes if (seed_nodes.empty()) { @@ -262,9 +243,13 @@ void KernelGraph::SetExecOrderByDefault() { } // delay execute comm ops that need optimize bool is_fused_comm = AnfAlgo::IsFusedCommunicationOp(node); - bool optimize_comm = is_fused_comm; - if (optimize_comm) { - optimize_comm = NeedOptimizeCommOp(node, &optimized_comm_group); + bool optimize_comm = false; + if (is_fused_comm && optimized_comm_group.empty()) { + auto node_group = GetNodeGroup(node); + if (node_group.find(kSyncBnGroup) == string::npos) { + optimized_comm_group = node_group; + optimize_comm = true; + } } if (optimize_comm) { while (!delay_comm_stack.empty()) { diff --git a/mindspore/ccsrc/frontend/parallel/context.cc b/mindspore/ccsrc/frontend/parallel/context.cc index be7b07063db..66ef32d5ea1 100644 --- a/mindspore/ccsrc/frontend/parallel/context.cc +++ b/mindspore/ccsrc/frontend/parallel/context.cc @@ -166,11 +166,11 @@ void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector & if (!group.empty() && group.find(TypeIdLabel(kNumberTypeFloat)) == std::string::npos && group.find(TypeIdLabel(kNumberTypeFloat16)) == std::string::npos && group.find(TypeIdLabel(kNumberTypeFloat32)) == std::string::npos) { - all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat)] = sizes; - all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat16)] = sizes; - all_reduce_fusion_split_indices_[group + TypeIdLabel(kNumberTypeFloat32)] = sizes; + all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat)] = sizes; + all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat16)] = sizes; + all_reduce_fusion_split_sizes_[group + TypeIdLabel(kNumberTypeFloat32)] = sizes; } else { - all_reduce_fusion_split_indices_[group] = sizes; + all_reduce_fusion_split_sizes_[group] = sizes; } }