diff --git a/mindspore/ccsrc/backend/common/pass/communication_op_fusion.cc b/mindspore/ccsrc/backend/common/pass/communication_op_fusion.cc index c2441b00063..a3accae2704 100644 --- a/mindspore/ccsrc/backend/common/pass/communication_op_fusion.cc +++ b/mindspore/ccsrc/backend/common/pass/communication_op_fusion.cc @@ -187,7 +187,7 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic void CommunicationOpFusion::GetAllReduceSplitSegment(const std::vector &nodes, int64_t threshold, std::vector *segment_index) const { MS_EXCEPTION_IF_NULL(segment_index); - if (threshold <= 0) { + if (threshold < 0) { MS_LOG(INFO) << "Split threshold is " << threshold << ". AllReduce nodes will take default fusion strategy."; return; } @@ -476,6 +476,12 @@ bool CommunicationOpFusion::DoFusion(const FuncGraphPtr &func_graph, const Commu bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); + auto parallel_context = parallel::ParallelContext::GetInstance(); + MS_EXCEPTION_IF_NULL(parallel_context); + auto threshold = parallel_context->dp_fusion_threshold_mb(); + if (threshold == 0) { + return false; + } const float input_grad_size_num = 0.0; const float input_grad_time_num = 0.0; // divide candidate fusion groups with same (group,op,fusion,dtype) attrs, fusion==0 means not fusion diff --git a/mindspore/ccsrc/include/common/utils/parallel_context.h b/mindspore/ccsrc/include/common/utils/parallel_context.h index 49bf1a1a609..cb4c11dcb39 100644 --- a/mindspore/ccsrc/include/common/utils/parallel_context.h +++ b/mindspore/ccsrc/include/common/utils/parallel_context.h @@ -54,7 +54,7 @@ constexpr char kFusionAuto[] = "auto"; constexpr char kFusionSize[] = "size"; constexpr char kFusionIndex[] = "index"; constexpr int64_t kFusionThreshold = 64; -constexpr int64_t kDataParallelFusionThreshold = 0; +constexpr int64_t kDataParallelFusionThreshold = -1; class COMMON_EXPORT ParallelContext { public: