!15542 split axis and batch for gather

From: @yangzhenzhang
Reviewed-by: @kisnwang, @stsuteng
Signed-off-by: @stsuteng
mindspore-ci-bot 2021-04-25 19:33:09 +08:00 committed by Gitee
commit 49d6c029a6
3 changed files with 91 additions and 0 deletions


@@ -255,6 +255,38 @@ Status GatherPInfo::CheckSplitAxisStrategy(const StrategyPtr &strategy) {
  return SUCCESS;
}
// Returns true only when axis is 0 and the strategy splits both the first dimension of the
// parameter and the first dimension of the indices; otherwise returns false.
bool GatherPInfo::ShardBatchAndAxis(const Strategys &strategy) {
  if (axis_ != 0) {
    return false;
  }
  if (strategy.size() != 2) {
    return false;
  }
  Dimensions param_strategy = strategy[0];
  Dimensions indices_strategy = strategy[1];
  if ((param_strategy.size() != 2) || (indices_strategy.size() != 2)) {
    return false;
  }
  // only the first dimension of the parameter and of the indices may be split
  if ((param_strategy[1] != 1) || (indices_strategy[1] != 1)) {
    return false;
  }
  // the two splits together must use exactly all devices in the stage
  if (param_strategy[0] * indices_strategy[0] != stage_device_size_) {
    return false;
  }
  // if either dimension alone takes all devices, the existing single-split paths handle it
  if ((param_strategy[0] == stage_device_size_) || (indices_strategy[0] == stage_device_size_)) {
    return false;
  }
  return true;
}
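To make the acceptance condition concrete, here is a minimal standalone sketch of the same predicate; the free-function form, the local type aliases, and the explicit stage_device_size parameter are illustrative stand-ins, not the real class members.

```cpp
// A standalone sketch of the ShardBatchAndAxis check, assuming axis == 0.
#include <cstdint>
#include <iostream>
#include <vector>

using Dimensions = std::vector<int64_t>;
using Strategys = std::vector<Dimensions>;

bool ShardBatchAndAxisSketch(const Strategys &strategy, int64_t stage_device_size) {
  if (strategy.size() != 2) {
    return false;
  }
  const Dimensions &param = strategy[0];
  const Dimensions &indices = strategy[1];
  if (param.size() != 2 || indices.size() != 2) {
    return false;
  }
  if (param[1] != 1 || indices[1] != 1) {
    return false;  // only the first dimensions may be split
  }
  if (param[0] * indices[0] != stage_device_size) {
    return false;  // the two splits must use exactly all devices
  }
  // if either dimension alone takes all devices, other branches apply
  return param[0] != stage_device_size && indices[0] != stage_device_size;
}

int main() {
  std::cout << ShardBatchAndAxisSketch({{4, 1}, {2, 1}}, 8) << "\n";  // 1: matches the new test
  std::cout << ShardBatchAndAxisSketch({{8, 1}, {1, 1}}, 8) << "\n";  // 0: only the axis is split
}
```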
Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) {
  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
    return FAILED;
@@ -286,6 +318,9 @@ Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) {
  if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0) && !dynamic_shape_indices_) {
    MS_LOG(INFO) << name_ << ": index_shape(0) can't be divided by param_strategy(0), use allreduce in forward";
    axis_split_forward_allreduce_ = true;
  } else if (is_auto_parallel_) {
    // in auto parallel mode this function is called many times, so the flag must be reset
    axis_split_forward_allreduce_ = false;
  }
  if (manual_split_) {
@@ -296,6 +331,17 @@ Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) {
    return SUCCESS;
  }
  if (ShardBatchAndAxis(strategy->GetInputDim())) {
    shard_batch_and_axis_ = true;
    axis_split_forward_allreduce_ = true;
    MS_LOG(INFO) << name_ << ": Sharding batch and axis, and the forward uses allreduce";
    return SUCCESS;
  } else if (is_auto_parallel_) {
    // in auto parallel mode this function is called many times, so the flags must be reset
    shard_batch_and_axis_ = false;
    axis_split_forward_allreduce_ = false;
  }
  // when axis != 0, param_shape(0) must be divisible by (param_strategy(0) * param_strategy(axis))
  if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(LongToSize(axis_))) != 0) {
    MS_LOG(DEBUG) << name_ << ": param_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
@@ -357,6 +403,15 @@ Status GatherPInfo::InferDevMatrixShape() {
    return SUCCESS;
  }
  if (shard_batch_and_axis_) {
    dev_matrix_shape_ = {index_strategy[0], param_strategy[0]};
    // if the forward used reducescatter, the dev matrix would be {index_strategy[0] * param_strategy[0]}
    out_dev_matrix_shape_ = dev_matrix_shape_;
    MS_LOG(INFO) << name_ << ": Sharding batch and axis, the dev matrix is " << dev_matrix_shape_
                 << ", out dev matrix is " << out_dev_matrix_shape_;
    return SUCCESS;
  }
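For the strategy used in the new test, ((4, 1), (2, 1)) on 8 devices, this yields a {2, 4} device matrix. A small sketch of how ranks could map onto it, assuming a row-major rank layout (the framework's actual ordering is not shown in this diff):

```cpp
// Ranks laid out row-major over the {index_strategy[0], param_strategy[0]} = {2, 4}
// device matrix: the row coordinate selects the indices slice, the column the param slice.
#include <cstdint>
#include <iostream>

int main() {
  const int64_t dim0 = 2;  // index_strategy[0]
  const int64_t dim1 = 4;  // param_strategy[0]
  for (int64_t rank = 0; rank < dim0 * dim1; ++rank) {
    std::cout << "rank " << rank << " -> (" << rank / dim1 << ", " << rank % dim1 << ")\n";
  }
}
```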
  dev_matrix_shape_ = param_strategy;
  // param_strategy(axis) == 1,
@@ -473,6 +528,13 @@ Status GatherPInfo::InferTensorMap() {
    outputs_tensor_map_.push_back({-1, 1, 0});
    return SUCCESS;
  }
  if (shard_batch_and_axis_) {
    inputs_tensor_map_.push_back({0, -1});       // param
    inputs_tensor_map_.push_back({1, -1});       // indices
    outputs_tensor_map_.push_back({1, -1, -1});  // output; if the forward used reducescatter, the tensor map would be {0, -1, -1}
    return SUCCESS;
  }
  InferInputsTensorMap();
  InferOutputsTensorMap();
  return SUCCESS;
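A sketch of how these tensor maps translate into per-device slice shapes, assuming the MindSpore convention that a map value k refers to device-matrix dimension n-1-k counted from the left (i.e., dimensions are numbered from the right) and -1 means replicated:

```cpp
// Slice shapes implied by the tensor maps above for the {2, 4} device matrix.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> SliceShape(const std::vector<int64_t> &shape, const std::vector<int64_t> &tensor_map,
                                const std::vector<int64_t> &dev_matrix) {
  std::vector<int64_t> slice = shape;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (tensor_map[i] >= 0) {
      // map value k shards this dim over dev_matrix dim n-1-k
      slice[i] /= dev_matrix[dev_matrix.size() - 1 - static_cast<size_t>(tensor_map[i])];
    }
  }
  return slice;
}

int main() {
  const std::vector<int64_t> dev_matrix{2, 4};  // {index_strategy[0], param_strategy[0]}
  auto param = SliceShape({64, 64}, {0, -1}, dev_matrix);          // {16, 64}
  auto indices = SliceShape({2, 64}, {1, -1}, dev_matrix);         // {1, 64}
  auto output = SliceShape({2, 64, 64}, {1, -1, -1}, dev_matrix);  // {1, 64, 64}
  std::cout << param[0] << " " << indices[0] << " " << output[0] << "\n";  // 16 1 1
}
```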
@@ -516,6 +578,15 @@ Status GatherPInfo::InferBias() {
  int64_t rank = g_device_manager->rank_index_in_stage();
  auto input_shape = inputs_shape_.at(0);
  auto params_strategy = strategy_->GetInputDim().at(0);
  if (shard_batch_and_axis_) {
    slice_size_ = input_shape[0] / params_strategy[0];
    // the offset of this rank's parameter slice along dimension 0
    bias_ = (rank % params_strategy[0]) * slice_size_;
    MS_LOG(INFO) << name_ << ": Sharding batch and axis, the rank is " << rank << ", slice size is " << slice_size_
                 << ", bias is " << bias_;
    return SUCCESS;
  }
  // the axis is not split
  if (params_strategy.at(axis_) == 1) {
    bias_ = 0;
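A worked version of the new bias computation under the test's shapes (input_shape[0] = 64, params_strategy[0] = 4, hence slice_size_ = 16):

```cpp
// bias_ = (rank % params_strategy[0]) * slice_size_: the offset of this rank's
// parameter slice along dimension 0, shown here for the 8-device test setup.
#include <cstdint>
#include <iostream>

int main() {
  const int64_t dim0 = 64;        // input_shape[0]
  const int64_t param_split = 4;  // params_strategy[0]
  const int64_t slice_size = dim0 / param_split;  // 16
  for (int64_t rank = 0; rank < 8; ++rank) {
    std::cout << "rank " << rank << ": bias " << (rank % param_split) * slice_size << "\n";
  }
  // ranks 0..3 get biases 0, 16, 32, 48; ranks 4..7 repeat the same offsets
}
```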
@@ -598,6 +669,11 @@ Status GatherPInfo::InferGroup() {
    dim = dim + 1;
  }
  if (shard_batch_and_axis_) {
    dim = 1;
    MS_LOG(INFO) << name_ << ": Sharding batch and axis, the group dim is " << dim;
  }
  if (dev_matrix.GetDevicesAlongDim(SizeToUlong(dim), &group_devices) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Create group failed.";
    return FAILED;
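With the {2, 4} device matrix and dim = 1, the group contains the four ranks that split the parameter's axis dimension. A sketch of that grouping, again assuming a row-major rank layout rather than GetDevicesAlongDim's actual implementation:

```cpp
// Devices along dim 1 of a row-major {2, 4} device matrix: each group holds the
// ranks that differ only in their param-split coordinate, i.e. the allreduce group.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> DevicesAlongDim1(int64_t rank, int64_t dim1) {
  const int64_t row = rank / dim1;  // coordinate along dim 0 (the index split)
  std::vector<int64_t> group;
  for (int64_t j = 0; j < dim1; ++j) {
    group.push_back(row * dim1 + j);
  }
  return group;
}

int main() {
  for (int64_t rank : {0, 5}) {
    std::cout << "rank " << rank << ":";
    for (int64_t d : DevicesAlongDim1(rank, 4)) {
      std::cout << " " << d;
    }
    std::cout << "\n";  // rank 0: 0 1 2 3; rank 5: 4 5 6 7
  }
}
```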


@@ -70,6 +70,7 @@ class GatherPInfo : public OperatorInfo {
  Status InferBias();
  Status InferOffset();
  Status InferGroup();
  bool ShardBatchAndAxis(const Strategys &strategy);
  int64_t axis_;
  std::string target_ = DEVICE;
@@ -82,6 +83,7 @@ class GatherPInfo : public OperatorInfo {
  bool manual_split_ = false;
  bool dynamic_shape_indices_ = false;
  bool axis_split_forward_allreduce_ = false;  // when the axis is split, reducescatter is used by default in the forward
  bool shard_batch_and_axis_ = false;
  std::vector<int64_t> param_split_shapes_;
  std::vector<int64_t> index_offsets_;
};


@@ -190,6 +190,19 @@ def test_gatherv2_forward_all_reduce():
    _executor.compile(net, x, y)


def test_gatherv2_shard_batch_and_axis():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    strategy1 = ((4, 1), (2, 1))
    strategy2 = ((2, 4, 1), (2, 4, 1))
    net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2, shape=[2, 64])))
    net.set_auto_parallel()
    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    y = Tensor(np.ones([2, 64, 64]), dtype=ms.float32)
    net.set_train()
    _executor.compile(net, x, y)
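This test exercises exactly the new path: strategy1 splits only the first dimensions (param 4 ways, indices 2 ways), 4 * 2 = 8 matches device_num, and neither factor alone equals 8, so ShardBatchAndAxis accepts the strategy and the forward gather is completed with an allreduce.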

def test_gatherv2_split_axis_0_repeat_calc():
    context.set_auto_parallel_context(device_num=8, global_rank=7, parallel_mode="semi_auto_parallel")
    strategy1 = ((4, 1), (1, 1))