!39165 change default fusion size for DataParallel
Merge pull request !39165 from baihuawei/allreducemaster
Commit b2f4630083
@@ -187,7 +187,7 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic
 void CommunicationOpFusion::GetAllReduceSplitSegment(const std::vector<CNodePtr> &nodes, int64_t threshold,
                                                      std::vector<size_t> *segment_index) const {
   MS_EXCEPTION_IF_NULL(segment_index);
-  if (threshold <= 0) {
+  if (threshold < 0) {
     MS_LOG(INFO) << "Split threshold is " << threshold << ". AllReduce nodes will take default fusion strategy.";
     return;
   }
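Note on the hunk above: with "<= 0" relaxed to "< 0", only a negative threshold now falls back to the default AllReduce fusion strategy; a threshold of 0 is handled earlier, in Run() below, where it disables the fusion pass altogether. The following is a minimal standalone sketch of the size-threshold splitting idea — the function name, the plain MB inputs, and the simplified accumulation are illustrative assumptions, not the actual GetAllReduceSplitSegment implementation, which walks CNodePtr nodes.

#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative only: split a list of gradient sizes (in MB) into fusion segments
// so that each fused AllReduce stays near threshold_mb. The real pass works on
// CNodePtr nodes and writes segment end indices into *segment_index; this sketch
// mirrors only the accumulation logic, using plain numbers.
std::vector<size_t> SplitBySize(const std::vector<double> &sizes_mb, int64_t threshold_mb) {
  std::vector<size_t> segment_index;
  if (sizes_mb.empty()) {
    return segment_index;
  }
  if (threshold_mb < 0) {
    // Mirrors the patched check: a negative threshold means "use the default
    // fusion strategy", so no size-based segments are produced here.
    return segment_index;
  }
  double accumulated = 0.0;
  for (size_t i = 0; i < sizes_mb.size(); ++i) {
    accumulated += sizes_mb[i];
    if (accumulated >= static_cast<double>(threshold_mb)) {
      segment_index.push_back(i);  // close the current fusion segment at node i
      accumulated = 0.0;
    }
  }
  if (segment_index.back() != sizes_mb.size() - 1) {
    segment_index.push_back(sizes_mb.size() - 1);  // the remainder forms the last segment
  }
  return segment_index;
}

int main() {
  // Six 20 MB gradients with a 64 MB threshold -> segments end at indices 3 and 5.
  for (size_t idx : SplitBySize({20, 20, 20, 20, 20, 20}, 64)) {
    std::cout << idx << " ";
  }
  std::cout << std::endl;
  return 0;
}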
@@ -476,6 +476,12 @@ bool CommunicationOpFusion::DoFusion(const FuncGraphPtr &func_graph, const Commu
 
 bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) {
   MS_EXCEPTION_IF_NULL(func_graph);
+  auto parallel_context = parallel::ParallelContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(parallel_context);
+  auto threshold = parallel_context->dp_fusion_threshold_mb();
+  if (threshold == 0) {
+    return false;
+  }
   const float input_grad_size_num = 0.0;
   const float input_grad_time_num = 0.0;
   // divide candidate fusion groups with same (group,op,fusion,dtype) attrs, fusion==0 means not fusion
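Combined with the new default in the header below, dp_fusion_threshold_mb() now distinguishes three regimes: 0 disables this fusion pass (Run() returns false), a negative value keeps the pass but leaves splitting to the default fusion strategy, and a positive value is the per-segment size threshold in MB. The tiny sketch below only restates that dispatch; DescribeThreshold is a hypothetical helper, not part of MindSpore.

#include <cstdint>
#include <iostream>

// Illustrative summary of the three threshold regimes after this change;
// DescribeThreshold is a hypothetical helper, not part of MindSpore.
const char *DescribeThreshold(int64_t threshold_mb) {
  if (threshold_mb == 0) {
    return "fusion pass disabled: Run() returns false before any grouping";
  }
  if (threshold_mb < 0) {
    return "default fusion strategy: size-based splitting is skipped";
  }
  return "size-based fusion: a segment is closed once it reaches threshold_mb";
}

int main() {
  for (int64_t t : {-1, 0, 64}) {
    std::cout << t << " MB -> " << DescribeThreshold(t) << std::endl;
  }
  return 0;
}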
@@ -54,7 +54,7 @@ constexpr char kFusionAuto[] = "auto";
 constexpr char kFusionSize[] = "size";
 constexpr char kFusionIndex[] = "index";
 constexpr int64_t kFusionThreshold = 64;
-constexpr int64_t kDataParallelFusionThreshold = 0;
+constexpr int64_t kDataParallelFusionThreshold = -1;
 
 class COMMON_EXPORT ParallelContext {
  public:
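The header change moves the compile-time default from 0 to -1 so that, under the new Run() check, a freshly constructed context still runs the fusion pass with the default strategy instead of silently disabling it. Below is a minimal sketch, assuming only what the hunk shows, of a singleton context holding that default; the class name and the setter are hypothetical, while ParallelContext, dp_fusion_threshold_mb(), and the constant come from the actual header.

#include <cstdint>
#include <iostream>

// Illustrative stand-in for the real ParallelContext singleton; only the
// threshold-related pieces shown in the hunk above are modeled here.
constexpr int64_t kDataParallelFusionThreshold = -1;  // new default: keep the pass, use the default strategy

class ParallelContextSketch {
 public:
  static ParallelContextSketch &GetInstance() {
    static ParallelContextSketch instance;
    return instance;
  }
  int64_t dp_fusion_threshold_mb() const { return dp_fusion_threshold_mb_; }
  // Hypothetical setter for this sketch; the real setter is not shown in the diff.
  void set_dp_fusion_threshold_mb(int64_t mb) { dp_fusion_threshold_mb_ = mb; }

 private:
  ParallelContextSketch() = default;
  int64_t dp_fusion_threshold_mb_ = kDataParallelFusionThreshold;
};

int main() {
  auto &ctx = ParallelContextSketch::GetInstance();
  std::cout << "default threshold: " << ctx.dp_fusion_threshold_mb()
            << " (negative keeps the fusion pass active)" << std::endl;
  ctx.set_dp_fusion_threshold_mb(64);
  std::cout << "user-configured threshold: " << ctx.dp_fusion_threshold_mb() << " MB" << std::endl;
  return 0;
}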