!17496 Set Repeat Device Matrix Direction According to the Shard

From: @huangxinjing
Reviewed-by: @stsuteng,@yangzhenzhang
Signed-off-by: @stsuteng
This commit is contained in:
mindspore-ci-bot 2021-06-02 19:34:46 +08:00 committed by Gitee
commit c0d9e726fb
2 changed files with 35 additions and 15 deletions

View File

@ -285,6 +285,35 @@ bool GatherPInfo::ShardBatchAndAxis(const Strategys &strategy) {
return true;
}
Status GatherPInfo::SetAttribute(const StrategyPtr &strategy) {
auto param_strategy = strategy->GetInputDim().at(0);
// axis=0, index_shape(0)%param_strategy(0) must be 0
Shape index_shape = inputs_shape_.at(1);
if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0) && !dynamic_shape_indices_) {
MS_LOG(INFO) << name_ << ": index_shape(0) can't be divided by param_strategy(0), use allreduce in forward";
axis_split_forward_allreduce_ = true;
} else if (is_auto_parallel_) {
// in auto parallel mode, this function will be called many times, so need to reset the flags
axis_split_forward_allreduce_ = false;
}
auto product_param = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
// Cast 1: If repeated calculation, need to set repeated num to the left of dev-matrix. For example,
// parameter strategy is [8, 1], indices strategy is [1, 1], dev num is 16,
// and dev_matrix is [2, 1, 8, 1, 1], the communication groups are [0, 8] and [0, 1, 2, 3, 4, 5, 6, 7], they
// can communicate normally, and dev0 to dev7 have the all parameters.
// Cast 2: If not repeated calculation(such as data parallel), need to set repeated num to the right,
// as it's easy to introduce the redistribution after or before gather operation, influencing the performance.
if (product_param == stage_device_size_ || product_param == 1) {
repeated_num_in_dev_matrix_right_ = true;
} else {
repeated_num_in_dev_matrix_right_ = false;
}
MS_LOG(INFO) << "Set repeated_num_in_dev_matrix_right for gather to " << repeated_num_in_dev_matrix_right_;
return SUCCESS;
}
Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
return FAILED;
@ -322,16 +351,6 @@ Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) {
axis_split_forward_allreduce_ = false;
}
// axis=0, index_shape(0)%param_strategy(0) must be 0
Shape index_shape = inputs_shape_.at(1);
if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0) && !dynamic_shape_indices_) {
MS_LOG(INFO) << name_ << ": index_shape(0) can't be divided by param_strategy(0), use allreduce in forward";
axis_split_forward_allreduce_ = true;
} else if (is_auto_parallel_) {
// in auto parallel mode, this function will be called many times, so need to reset the flags
axis_split_forward_allreduce_ = false;
}
if (manual_split_) {
if (CheckManualSplit(strategy->GetInputDim()) != SUCCESS) {
return FAILED;
@ -350,11 +369,11 @@ Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED;
}
// If repeated calculation, need to set repeated num to the left of dev-matrix. For example,
// parameter strategy is [8, 1], indices strategy is [1, 1], dev num is 16,
// and dev_matrix is [2, 1, 8, 1, 1], the communication groups are [0, 8] and [0, 1, 2, 3, 4, 5, 6, 7], they
// can communicate normally, and dev0 to dev7 have the all parameters.
repeated_num_in_dev_matrix_right_ = false;
// According to the strategy, set the private members.
if (SetAttribute(strategy) != SUCCESS) {
return FAILED;
}
return SUCCESS;
}

View File

@ -64,6 +64,7 @@ class GatherPInfo : public OperatorInfo {
Status ComputeReplaceGraph(const CNodePtr &cnode);
Status CheckManualSplit(const Strategys &strategy);
Status CheckSplitAxisStrategy(const StrategyPtr &strategy);
Status SetAttribute(const StrategyPtr &strategy);
Status GetManualSplitAttr();
Status GetManualSplitWithoutOffsetAttr();
Status ComputeReplaceOp();