diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc index 6e828d71dc1..6317664f6dc 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc @@ -228,6 +228,33 @@ Status GatherPInfo::CheckManualSplit(const Strategys &strategy) { return SUCCESS; } +Status GatherPInfo::CheckSplitAxisStrategy(const StrategyPtr &strategy) { + auto param_strategy = strategy->GetInputDim().at(0); + auto index_strategy = strategy->GetInputDim().at(1); + // param_strategy(axis) != 1, index can't be split + auto product_i = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies()); + if ((param_strategy.at(LongToSize(axis_)) != 1) && (product_i != 1)) { + MS_LOG(DEBUG) << name_ << ": param is split at dim (axis)" << axis_ << " ,index can't be split."; + return FAILED; + } + + // param_strategy(axis) != 1, and axis != 0, don't support repeated calc + auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies()); + if ((product_p != stage_device_size_) && (param_strategy.at(IntToSize(axis_)) != 1) && (axis_ != 0)) { + MS_LOG(DEBUG) << name_ << ": Invalid strategy. Don't support repeated calc."; + return FAILED; + } + + if ((product_p != stage_device_size_) && (param_strategy.at(IntToSize(axis_)) != 1) && (axis_ == 0)) { + if ((param_strategy.size() == 2) && (param_strategy[1] != 1)) { + MS_LOG(DEBUG) << name_ << ": axis(0) is split, and param_strategy[1] != 1, don't support repeated calc."; + return FAILED; + } + MS_LOG(INFO) << name_ << ": split axis(0) and repeat calculation"; + } + return SUCCESS; +} + Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { return FAILED; @@ -275,29 +302,10 @@ Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) { return FAILED; } - // param_strategy(axis) != 1, index can't be split - auto index_strategy = strategy->GetInputDim().at(1); - auto product_i = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies()); - if ((param_strategy.at(LongToSize(axis_)) != 1) && (product_i != 1)) { - MS_LOG(DEBUG) << name_ << ": param is split at dim (axis)" << axis_ << " ,index can't be split."; + if (CheckSplitAxisStrategy(strategy) != SUCCESS) { return FAILED; } - // param_strategy(axis) != 1, and axis != 0, don't support repeated calc - auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies()); - if ((product_p != stage_device_size_) && (param_strategy.at(IntToSize(axis_)) != 1) && (axis_ != 0)) { - MS_LOG(DEBUG) << name_ << ": Invalid strategy. Don't support repeated calc."; - return FAILED; - } - - if ((product_p != stage_device_size_) && (param_strategy.at(IntToSize(axis_)) != 1) && (axis_ == 0)) { - if ((param_strategy.size() == 2) && (param_strategy[1] != 1)) { - MS_LOG(DEBUG) << name_ << ": axis(0) is split, and param_strategy[1] != 1, don't support repeated calc."; - return FAILED; - } - MS_LOG(INFO) << name_ << ": split axis(0) and repeat calculation"; - } - // If repeated calculation, need to set repeated num to the left of dev-matrix. For example, // parameter strategy is [8, 1], indices strategy is [1, 1], dev num is 16, // and dev_matrix is [2, 1, 8, 1, 1], the communication groups are [0, 8] and [0, 1, 2, 3, 4, 5, 6, 7], they @@ -351,17 +359,14 @@ Status GatherPInfo::InferDevMatrixShape() { dev_matrix_shape_ = param_strategy; - // param_strategy(axis)!=1, - if (param_strategy.at(LongToSize(axis_)) != 1) { - std::reverse(dev_matrix_shape_.begin(), dev_matrix_shape_.end()); - } else { + // param_strategy(axis)==1, + if (param_strategy.at(LongToSize(axis_)) == 1) { dev_matrix_shape_.insert(dev_matrix_shape_.end(), index_strategy.begin(), index_strategy.end()); } // infer out dev_matrix_shape // axis!=0, split axis if (axis_ != 0 && param_strategy.at(LongToSize(axis_)) != 1) { - out_dev_matrix_shape_.push_back(param_strategy.at(0) * param_strategy.at(LongToSize(axis_))); for (size_t i = 1; i < param_strategy.size(); ++i) { if (i == LongToSize(axis_)) { out_dev_matrix_shape_.push_back(1); @@ -369,6 +374,7 @@ Status GatherPInfo::InferDevMatrixShape() { out_dev_matrix_shape_.push_back(param_strategy.at(i)); } } + out_dev_matrix_shape_.push_back(param_strategy.at(0) * param_strategy.at(LongToSize(axis_))); } else { out_dev_matrix_shape_ = dev_matrix_shape_; } @@ -398,7 +404,7 @@ void GatherPInfo::InferInputsTensorMap() { if (param_strategy.at(LongToSize(axis_)) != 1) { tensor_map_index.insert(tensor_map_index.begin(), index_size, MAP_NONE); for (size_t i = 0; i < param_size; ++i) { - tensor_map_params.push_back(SizeToLong(i)); + tensor_map_params.push_back(SizeToLong(param_size - i - 1)); } } else { // param_strategy(axis) == 1 @@ -438,11 +444,11 @@ void GatherPInfo::InferOutputsTensorMap() { // the output is repeat calculation tensor_map_out.insert(tensor_map_out.end(), MAP_NONE); } else { - tensor_map_out.insert(tensor_map_out.end(), 0); + tensor_map_out.insert(tensor_map_out.end(), param_size - 1); } tensor_map_out.insert(tensor_map_out.end(), index_size - 1, MAP_NONE); for (size_t i = 1; i < param_size; ++i) { - tensor_map_out.push_back(i); + tensor_map_out.push_back(param_size - 1 - i); } } else { for (size_t i = 0; i < param_size; ++i) { @@ -452,7 +458,7 @@ void GatherPInfo::InferOutputsTensorMap() { if (i == 0 && dynamic_shape_indices_ && target_ != CPU) { tensor_map_out.push_back(MAP_NONE); } - tensor_map_out.push_back(SizeToLong(param_size - i - 1)); + tensor_map_out.push_back(SizeToLong(i)); } } } @@ -581,11 +587,7 @@ Status GatherPInfo::InferOffset() { } Status GatherPInfo::InferGroup() { - auto param_strategy = strategy_->GetInputDim().at(0); size_t dim = LongToSize(axis_); - if (param_strategy.at(LongToSize(axis_)) != 1 && inputs_shape_.at(0).size() == 2) { - dim = (axis_ + 1) % 2; - } int64_t rank = g_device_manager->global_rank(); DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h index 7d17313eb06..28871810664 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h @@ -63,6 +63,7 @@ class GatherPInfo : public OperatorInfo { Status ComputeReplaceGraph(const CNodePtr &cnode); Status CheckManualSplit(const Strategys &strategy); + Status CheckSplitAxisStrategy(const StrategyPtr &strategy); Status GetManualSplitAttr(); Status GetManualSplitWithoutOffsetAttr(); Status ComputeReplaceOp();