!39165 chage default fusion size for DataParallel

Merge pull request !39165 from baihuawei/allreducemaster
This commit is contained in:
i-robot 2022-07-30 01:44:55 +00:00 committed by Gitee
commit b2f4630083
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 8 additions and 2 deletions

View File

@ -187,7 +187,7 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic
void CommunicationOpFusion::GetAllReduceSplitSegment(const std::vector<CNodePtr> &nodes, int64_t threshold, void CommunicationOpFusion::GetAllReduceSplitSegment(const std::vector<CNodePtr> &nodes, int64_t threshold,
std::vector<size_t> *segment_index) const { std::vector<size_t> *segment_index) const {
MS_EXCEPTION_IF_NULL(segment_index); MS_EXCEPTION_IF_NULL(segment_index);
if (threshold <= 0) { if (threshold < 0) {
MS_LOG(INFO) << "Split threshold is " << threshold << ". AllReduce nodes will take default fusion strategy."; MS_LOG(INFO) << "Split threshold is " << threshold << ". AllReduce nodes will take default fusion strategy.";
return; return;
} }
@ -476,6 +476,12 @@ bool CommunicationOpFusion::DoFusion(const FuncGraphPtr &func_graph, const Commu
bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) { bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
auto parallel_context = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel_context);
auto threshold = parallel_context->dp_fusion_threshold_mb();
if (threshold == 0) {
return false;
}
const float input_grad_size_num = 0.0; const float input_grad_size_num = 0.0;
const float input_grad_time_num = 0.0; const float input_grad_time_num = 0.0;
// divide candidate fusion groups with same (group,op,fusion,dtype) attrs, fusion==0 means not fusion // divide candidate fusion groups with same (group,op,fusion,dtype) attrs, fusion==0 means not fusion

View File

@ -54,7 +54,7 @@ constexpr char kFusionAuto[] = "auto";
constexpr char kFusionSize[] = "size"; constexpr char kFusionSize[] = "size";
constexpr char kFusionIndex[] = "index"; constexpr char kFusionIndex[] = "index";
constexpr int64_t kFusionThreshold = 64; constexpr int64_t kFusionThreshold = 64;
constexpr int64_t kDataParallelFusionThreshold = 0; constexpr int64_t kDataParallelFusionThreshold = -1;
class COMMON_EXPORT ParallelContext { class COMMON_EXPORT ParallelContext {
public: public: