!39165 change default fusion size for DataParallel
Merge pull request !39165 from baihuawei/allreducemaster
This commit is contained in:
commit
b2f4630083
|
@ -187,7 +187,7 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic
|
||||||
void CommunicationOpFusion::GetAllReduceSplitSegment(const std::vector<CNodePtr> &nodes, int64_t threshold,
|
void CommunicationOpFusion::GetAllReduceSplitSegment(const std::vector<CNodePtr> &nodes, int64_t threshold,
|
||||||
std::vector<size_t> *segment_index) const {
|
std::vector<size_t> *segment_index) const {
|
||||||
MS_EXCEPTION_IF_NULL(segment_index);
|
MS_EXCEPTION_IF_NULL(segment_index);
|
||||||
if (threshold <= 0) {
|
if (threshold < 0) {
|
||||||
MS_LOG(INFO) << "Split threshold is " << threshold << ". AllReduce nodes will take default fusion strategy.";
|
MS_LOG(INFO) << "Split threshold is " << threshold << ". AllReduce nodes will take default fusion strategy.";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -476,6 +476,12 @@ bool CommunicationOpFusion::DoFusion(const FuncGraphPtr &func_graph, const Commu
|
||||||
|
|
||||||
bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) {
|
bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
auto parallel_context = parallel::ParallelContext::GetInstance();
|
||||||
|
MS_EXCEPTION_IF_NULL(parallel_context);
|
||||||
|
auto threshold = parallel_context->dp_fusion_threshold_mb();
|
||||||
|
if (threshold == 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
const float input_grad_size_num = 0.0;
|
const float input_grad_size_num = 0.0;
|
||||||
const float input_grad_time_num = 0.0;
|
const float input_grad_time_num = 0.0;
|
||||||
// divide candidate fusion groups with same (group,op,fusion,dtype) attrs, fusion==0 means not fusion
|
// divide candidate fusion groups with same (group,op,fusion,dtype) attrs, fusion==0 means not fusion
|
||||||
|
|
|
@ -54,7 +54,7 @@ constexpr char kFusionAuto[] = "auto";
|
||||||
constexpr char kFusionSize[] = "size";
|
constexpr char kFusionSize[] = "size";
|
||||||
constexpr char kFusionIndex[] = "index";
|
constexpr char kFusionIndex[] = "index";
|
||||||
constexpr int64_t kFusionThreshold = 64;
|
constexpr int64_t kFusionThreshold = 64;
|
||||||
constexpr int64_t kDataParallelFusionThreshold = 0;
|
constexpr int64_t kDataParallelFusionThreshold = -1;
|
||||||
|
|
||||||
class COMMON_EXPORT ParallelContext {
|
class COMMON_EXPORT ParallelContext {
|
||||||
public:
|
public:
|
||||||
|
|
Loading…
Reference in New Issue