!14699 [AutoParallel]Fix autoparallel gatherV2 bug

From: @lichen666
Reviewed-by: @kisnwang,@zhunaipan
Signed-off-by: @zhunaipan
This commit is contained in:
mindspore-ci-bot 2021-04-09 11:01:56 +08:00 committed by Gitee
commit 447a3a6fc2
2 changed files with 36 additions and 33 deletions

View File

@ -228,6 +228,33 @@ Status GatherPInfo::CheckManualSplit(const Strategys &strategy) {
return SUCCESS; return SUCCESS;
} }
Status GatherPInfo::CheckSplitAxisStrategy(const StrategyPtr &strategy) {
auto param_strategy = strategy->GetInputDim().at(0);
auto index_strategy = strategy->GetInputDim().at(1);
// param_strategy(axis) != 1, index can't be split
auto product_i = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int64_t>());
if ((param_strategy.at(LongToSize(axis_)) != 1) && (product_i != 1)) {
MS_LOG(DEBUG) << name_ << ": param is split at dim (axis)" << axis_ << " ,index can't be split.";
return FAILED;
}
// param_strategy(axis) != 1, and axis != 0, don't support repeated calc
auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
if ((product_p != stage_device_size_) && (param_strategy.at(IntToSize(axis_)) != 1) && (axis_ != 0)) {
MS_LOG(DEBUG) << name_ << ": Invalid strategy. Don't support repeated calc.";
return FAILED;
}
if ((product_p != stage_device_size_) && (param_strategy.at(IntToSize(axis_)) != 1) && (axis_ == 0)) {
if ((param_strategy.size() == 2) && (param_strategy[1] != 1)) {
MS_LOG(DEBUG) << name_ << ": axis(0) is split, and param_strategy[1] != 1, don't support repeated calc.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": split axis(0) and repeat calculation";
}
return SUCCESS;
}
Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) { Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
return FAILED; return FAILED;
@ -275,29 +302,10 @@ Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
// param_strategy(axis) != 1, index can't be split if (CheckSplitAxisStrategy(strategy) != SUCCESS) {
auto index_strategy = strategy->GetInputDim().at(1);
auto product_i = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int64_t>());
if ((param_strategy.at(LongToSize(axis_)) != 1) && (product_i != 1)) {
MS_LOG(DEBUG) << name_ << ": param is split at dim (axis)" << axis_ << " ,index can't be split.";
return FAILED; return FAILED;
} }
// param_strategy(axis) != 1, and axis != 0, don't support repeated calc
auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int64_t>());
if ((product_p != stage_device_size_) && (param_strategy.at(IntToSize(axis_)) != 1) && (axis_ != 0)) {
MS_LOG(DEBUG) << name_ << ": Invalid strategy. Don't support repeated calc.";
return FAILED;
}
if ((product_p != stage_device_size_) && (param_strategy.at(IntToSize(axis_)) != 1) && (axis_ == 0)) {
if ((param_strategy.size() == 2) && (param_strategy[1] != 1)) {
MS_LOG(DEBUG) << name_ << ": axis(0) is split, and param_strategy[1] != 1, don't support repeated calc.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": split axis(0) and repeat calculation";
}
// If repeated calculation, need to set repeated num to the left of dev-matrix. For example, // If repeated calculation, need to set repeated num to the left of dev-matrix. For example,
// parameter strategy is [8, 1], indices strategy is [1, 1], dev num is 16, // parameter strategy is [8, 1], indices strategy is [1, 1], dev num is 16,
// and dev_matrix is [2, 1, 8, 1, 1], the communication groups are [0, 8] and [0, 1, 2, 3, 4, 5, 6, 7], they // and dev_matrix is [2, 1, 8, 1, 1], the communication groups are [0, 8] and [0, 1, 2, 3, 4, 5, 6, 7], they
@ -351,17 +359,14 @@ Status GatherPInfo::InferDevMatrixShape() {
dev_matrix_shape_ = param_strategy; dev_matrix_shape_ = param_strategy;
// param_strategy(axis)!=1, // param_strategy(axis)==1,
if (param_strategy.at(LongToSize(axis_)) != 1) { if (param_strategy.at(LongToSize(axis_)) == 1) {
std::reverse(dev_matrix_shape_.begin(), dev_matrix_shape_.end());
} else {
dev_matrix_shape_.insert(dev_matrix_shape_.end(), index_strategy.begin(), index_strategy.end()); dev_matrix_shape_.insert(dev_matrix_shape_.end(), index_strategy.begin(), index_strategy.end());
} }
// infer out dev_matrix_shape // infer out dev_matrix_shape
// axis!=0, split axis // axis!=0, split axis
if (axis_ != 0 && param_strategy.at(LongToSize(axis_)) != 1) { if (axis_ != 0 && param_strategy.at(LongToSize(axis_)) != 1) {
out_dev_matrix_shape_.push_back(param_strategy.at(0) * param_strategy.at(LongToSize(axis_)));
for (size_t i = 1; i < param_strategy.size(); ++i) { for (size_t i = 1; i < param_strategy.size(); ++i) {
if (i == LongToSize(axis_)) { if (i == LongToSize(axis_)) {
out_dev_matrix_shape_.push_back(1); out_dev_matrix_shape_.push_back(1);
@ -369,6 +374,7 @@ Status GatherPInfo::InferDevMatrixShape() {
out_dev_matrix_shape_.push_back(param_strategy.at(i)); out_dev_matrix_shape_.push_back(param_strategy.at(i));
} }
} }
out_dev_matrix_shape_.push_back(param_strategy.at(0) * param_strategy.at(LongToSize(axis_)));
} else { } else {
out_dev_matrix_shape_ = dev_matrix_shape_; out_dev_matrix_shape_ = dev_matrix_shape_;
} }
@ -398,7 +404,7 @@ void GatherPInfo::InferInputsTensorMap() {
if (param_strategy.at(LongToSize(axis_)) != 1) { if (param_strategy.at(LongToSize(axis_)) != 1) {
tensor_map_index.insert(tensor_map_index.begin(), index_size, MAP_NONE); tensor_map_index.insert(tensor_map_index.begin(), index_size, MAP_NONE);
for (size_t i = 0; i < param_size; ++i) { for (size_t i = 0; i < param_size; ++i) {
tensor_map_params.push_back(SizeToLong(i)); tensor_map_params.push_back(SizeToLong(param_size - i - 1));
} }
} else { } else {
// param_strategy(axis) == 1 // param_strategy(axis) == 1
@ -438,11 +444,11 @@ void GatherPInfo::InferOutputsTensorMap() {
// the output is repeat calculation // the output is repeat calculation
tensor_map_out.insert(tensor_map_out.end(), MAP_NONE); tensor_map_out.insert(tensor_map_out.end(), MAP_NONE);
} else { } else {
tensor_map_out.insert(tensor_map_out.end(), 0); tensor_map_out.insert(tensor_map_out.end(), param_size - 1);
} }
tensor_map_out.insert(tensor_map_out.end(), index_size - 1, MAP_NONE); tensor_map_out.insert(tensor_map_out.end(), index_size - 1, MAP_NONE);
for (size_t i = 1; i < param_size; ++i) { for (size_t i = 1; i < param_size; ++i) {
tensor_map_out.push_back(i); tensor_map_out.push_back(param_size - 1 - i);
} }
} else { } else {
for (size_t i = 0; i < param_size; ++i) { for (size_t i = 0; i < param_size; ++i) {
@ -452,7 +458,7 @@ void GatherPInfo::InferOutputsTensorMap() {
if (i == 0 && dynamic_shape_indices_ && target_ != CPU) { if (i == 0 && dynamic_shape_indices_ && target_ != CPU) {
tensor_map_out.push_back(MAP_NONE); tensor_map_out.push_back(MAP_NONE);
} }
tensor_map_out.push_back(SizeToLong(param_size - i - 1)); tensor_map_out.push_back(SizeToLong(i));
} }
} }
} }
@ -581,11 +587,7 @@ Status GatherPInfo::InferOffset() {
} }
Status GatherPInfo::InferGroup() { Status GatherPInfo::InferGroup() {
auto param_strategy = strategy_->GetInputDim().at(0);
size_t dim = LongToSize(axis_); size_t dim = LongToSize(axis_);
if (param_strategy.at(LongToSize(axis_)) != 1 && inputs_shape_.at(0).size() == 2) {
dim = (axis_ + 1) % 2;
}
int64_t rank = g_device_manager->global_rank(); int64_t rank = g_device_manager->global_rank();
DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_); DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);

View File

@ -63,6 +63,7 @@ class GatherPInfo : public OperatorInfo {
Status ComputeReplaceGraph(const CNodePtr &cnode); Status ComputeReplaceGraph(const CNodePtr &cnode);
Status CheckManualSplit(const Strategys &strategy); Status CheckManualSplit(const Strategys &strategy);
Status CheckSplitAxisStrategy(const StrategyPtr &strategy);
Status GetManualSplitAttr(); Status GetManualSplitAttr();
Status GetManualSplitWithoutOffsetAttr(); Status GetManualSplitWithoutOffsetAttr();
Status ComputeReplaceOp(); Status ComputeReplaceOp();