forked from mindspore-Ecosystem/mindspore
!17496 Set Repeat Device Matrix Direction According to the Shard
From: @huangxinjing Reviewed-by: @stsuteng,@yangzhenzhang Signed-off-by: @stsuteng
This commit is contained in:
commit
c0d9e726fb
|
@ -285,6 +285,35 @@ bool GatherPInfo::ShardBatchAndAxis(const Strategys &strategy) {
|
|||
return true;
|
||||
}
|
||||
|
||||
Status GatherPInfo::SetAttribute(const StrategyPtr &strategy) {
|
||||
auto param_strategy = strategy->GetInputDim().at(0);
|
||||
// axis=0, index_shape(0)%param_strategy(0) must be 0
|
||||
Shape index_shape = inputs_shape_.at(1);
|
||||
if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0) && !dynamic_shape_indices_) {
|
||||
MS_LOG(INFO) << name_ << ": index_shape(0) can't be divided by param_strategy(0), use allreduce in forward";
|
||||
axis_split_forward_allreduce_ = true;
|
||||
} else if (is_auto_parallel_) {
|
||||
// in auto parallel mode, this function will be called many times, so need to reset the flags
|
||||
axis_split_forward_allreduce_ = false;
|
||||
}
|
||||
|
||||
auto product_param = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
|
||||
// Cast 1: If repeated calculation, need to set repeated num to the left of dev-matrix. For example,
|
||||
// parameter strategy is [8, 1], indices strategy is [1, 1], dev num is 16,
|
||||
// and dev_matrix is [2, 1, 8, 1, 1], the communication groups are [0, 8] and [0, 1, 2, 3, 4, 5, 6, 7], they
|
||||
// can communicate normally, and dev0 to dev7 have the all parameters.
|
||||
// Cast 2: If not repeated calculation(such as data parallel), need to set repeated num to the right,
|
||||
// as it's easy to introduce the redistribution after or before gather operation, influencing the performance.
|
||||
if (product_param == stage_device_size_ || product_param == 1) {
|
||||
repeated_num_in_dev_matrix_right_ = true;
|
||||
} else {
|
||||
repeated_num_in_dev_matrix_right_ = false;
|
||||
}
|
||||
MS_LOG(INFO) << "Set repeated_num_in_dev_matrix_right for gather to " << repeated_num_in_dev_matrix_right_;
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) {
|
||||
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
|
||||
return FAILED;
|
||||
|
@ -322,16 +351,6 @@ Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|||
axis_split_forward_allreduce_ = false;
|
||||
}
|
||||
|
||||
// axis=0, index_shape(0)%param_strategy(0) must be 0
|
||||
Shape index_shape = inputs_shape_.at(1);
|
||||
if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0) && !dynamic_shape_indices_) {
|
||||
MS_LOG(INFO) << name_ << ": index_shape(0) can't be divided by param_strategy(0), use allreduce in forward";
|
||||
axis_split_forward_allreduce_ = true;
|
||||
} else if (is_auto_parallel_) {
|
||||
// in auto parallel mode, this function will be called many times, so need to reset the flags
|
||||
axis_split_forward_allreduce_ = false;
|
||||
}
|
||||
|
||||
if (manual_split_) {
|
||||
if (CheckManualSplit(strategy->GetInputDim()) != SUCCESS) {
|
||||
return FAILED;
|
||||
|
@ -350,11 +369,11 @@ Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|||
return FAILED;
|
||||
}
|
||||
|
||||
// If repeated calculation, need to set repeated num to the left of dev-matrix. For example,
|
||||
// parameter strategy is [8, 1], indices strategy is [1, 1], dev num is 16,
|
||||
// and dev_matrix is [2, 1, 8, 1, 1], the communication groups are [0, 8] and [0, 1, 2, 3, 4, 5, 6, 7], they
|
||||
// can communicate normally, and dev0 to dev7 have the all parameters.
|
||||
repeated_num_in_dev_matrix_right_ = false;
|
||||
// According to the strategy, set the private members.
|
||||
if (SetAttribute(strategy) != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
|
|
|
@ -64,6 +64,7 @@ class GatherPInfo : public OperatorInfo {
|
|||
Status ComputeReplaceGraph(const CNodePtr &cnode);
|
||||
Status CheckManualSplit(const Strategys &strategy);
|
||||
Status CheckSplitAxisStrategy(const StrategyPtr &strategy);
|
||||
Status SetAttribute(const StrategyPtr &strategy);
|
||||
Status GetManualSplitAttr();
|
||||
Status GetManualSplitWithoutOffsetAttr();
|
||||
Status ComputeReplaceOp();
|
||||
|
|
Loading…
Reference in New Issue