forked from mindspore-Ecosystem/mindspore
!15542 split axis and batch for gather
From: @yangzhenzhang
Reviewed-by: @kisnwang, @stsuteng
Signed-off-by: @stsuteng
Commit: 49d6c029a6
@@ -255,6 +255,38 @@ Status GatherPInfo::CheckSplitAxisStrategy(const StrategyPtr &strategy) {
  return SUCCESS;
}

// return true: axis is 0, and split the first dimension of parameter and the first dimension of indices
// otherwise return false
bool GatherPInfo::ShardBatchAndAxis(const Strategys &strategy) {
  if (axis_ != 0) {
    return false;
  }

  if (strategy.size() != 2) {
    return false;
  }

  Dimensions param_strategy = strategy[0];
  Dimensions indices_strategy = strategy[1];
  if ((param_strategy.size() != 2) || (indices_strategy.size() != 2)) {
    return false;
  }

  if ((param_strategy[1] != 1) || (indices_strategy[1] != 1)) {
    return false;
  }

  if (param_strategy[0] * indices_strategy[0] != stage_device_size_) {
    return false;
  }

  if ((param_strategy[0] == stage_device_size_) || (indices_strategy[0] == stage_device_size_)) {
    return false;
  }

  return true;
}

Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) {
  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
    return FAILED;
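The predicate above accepts only a two-input, two-dimensional strategy where both second dimensions are 1, the two first dimensions multiply to exactly the stage device count, and neither input occupies all devices by itself. For reference, a minimal plain-Python restatement of the same checks (the helper name and the sample strategies are illustrative only; the ((4, 1), (2, 1)) case matches the new test further down):

# Plain-Python restatement of ShardBatchAndAxis (illustrative, not MindSpore API).
def shard_batch_and_axis(strategy, stage_device_size, axis=0):
    if axis != 0 or len(strategy) != 2:
        return False
    param, indices = strategy
    if len(param) != 2 or len(indices) != 2:
        return False
    if param[1] != 1 or indices[1] != 1:
        return False
    if param[0] * indices[0] != stage_device_size:
        return False
    # reject strategies where one input already covers every device
    if param[0] == stage_device_size or indices[0] == stage_device_size:
        return False
    return True

print(shard_batch_and_axis(((4, 1), (2, 1)), 8))  # True: 4 * 2 == 8, both inputs split
print(shard_batch_and_axis(((8, 1), (1, 1)), 8))  # False: the parameter alone uses all 8 devices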
@@ -286,6 +318,9 @@ Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) {
  if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0) && !dynamic_shape_indices_) {
    MS_LOG(INFO) << name_ << ": index_shape(0) can't be divided by param_strategy(0), use allreduce in forward";
    axis_split_forward_allreduce_ = true;
  } else if (is_auto_parallel_) {
    // in auto parallel mode, this function will be called many times, so need to reset the flags
    axis_split_forward_allreduce_ = false;
  }

  if (manual_split_) {
@@ -296,6 +331,17 @@ Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) {
    return SUCCESS;
  }

  if (ShardBatchAndAxis(strategy->GetInputDim())) {
    shard_batch_and_axis_ = true;
    axis_split_forward_allreduce_ = true;
    MS_LOG(INFO) << name_ << ": Sharding batch and axis, and the forward use allreduce";
    return SUCCESS;
  } else if (is_auto_parallel_) {
    // in auto parallel mode, this function will be called many times, so need to reset the flags
    shard_batch_and_axis_ = false;
    axis_split_forward_allreduce_ = false;
  }

  // axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0
  if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(LongToSize(axis_))) != 0) {
    MS_LOG(DEBUG) << name_ << ": param_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
@@ -357,6 +403,15 @@ Status GatherPInfo::InferDevMatrixShape() {
    return SUCCESS;
  }

  if (shard_batch_and_axis_) {
    dev_matrix_shape_ = {index_strategy[0], param_strategy[0]};
    // if forward use reducescatter, the dev matrix is {index_strategy[0] * param_strategy[0]}
    out_dev_matrix_shape_ = dev_matrix_shape_;
    MS_LOG(INFO) << name_ << ": Sharding batch and axis, the dev matrix is " << dev_matrix_shape_
                 << ", out dev matrix is " << out_dev_matrix_shape_;
    return SUCCESS;
  }

  dev_matrix_shape_ = param_strategy;

  // param_strategy(axis)==1,
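Taking the values from the new test below (8 devices, gather strategy ((4, 1), (2, 1))), the branch above yields a 2 x 4 device matrix: the leading dimension splits the indices' batch, the trailing one splits the gather axis. A small sketch with those assumed values:

# Assumed values from the test case (device_num=8, strategy ((4, 1), (2, 1))).
param_strategy = (4, 1)
index_strategy = (2, 1)
dev_matrix_shape = [index_strategy[0], param_strategy[0]]
print(dev_matrix_shape)  # [2, 4]
# Per the comment above, if the forward used reducescatter instead of allreduce,
# the device matrix would collapse to [index_strategy[0] * param_strategy[0]] = [8].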
@@ -473,6 +528,13 @@ Status GatherPInfo::InferTensorMap() {
    outputs_tensor_map_.push_back({-1, 1, 0});
    return SUCCESS;
  }

  if (shard_batch_and_axis_) {
    inputs_tensor_map_.push_back({0, -1});       // param
    inputs_tensor_map_.push_back({1, -1});       // indices
    outputs_tensor_map_.push_back({1, -1, -1});  // output, if forward use reducescatter, tensormap is {0, -1, -1}
    return SUCCESS;
  }
  InferInputsTensorMap();
  InferOutputsTensorMap();
  return SUCCESS;
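As a reading aid for the new tensor maps, assuming MindSpore's usual convention that a map entry k refers to the device-matrix dimension counted from the right (entry 0 is the last dimension) and -1 means the tensor dimension is not split: the param's first dimension follows the axis split of size 4, the indices' first dimension follows the batch split of size 2, and the output follows the indices. A hedged sketch with shapes assumed from the test below (param [64, 64], indices built from shape=[2, 64], output [2, 64, 64]); the helper is illustrative, not MindSpore API:

# Hedged sketch: derive per-device slice shapes from a tensor map and device matrix,
# assuming entry k maps to dev_matrix[len(dev_matrix) - 1 - k] and -1 means "not split".
def slice_shape(shape, tensor_map, dev_matrix):
    out = []
    for dim, m in zip(shape, tensor_map):
        split = 1 if m == -1 else dev_matrix[len(dev_matrix) - 1 - m]
        out.append(dim // split)
    return out

dev_matrix = [2, 4]  # {index_strategy[0], param_strategy[0]} from InferDevMatrixShape
print(slice_shape([64, 64], [0, -1], dev_matrix))         # param slice:   [16, 64]
print(slice_shape([2, 64], [1, -1], dev_matrix))          # indices slice: [1, 64]
print(slice_shape([2, 64, 64], [1, -1, -1], dev_matrix))  # output slice:  [1, 64, 64]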
@@ -516,6 +578,15 @@ Status GatherPInfo::InferBias() {
  int64_t rank = g_device_manager->rank_index_in_stage();
  auto input_shape = inputs_shape_.at(0);
  auto params_strategy = strategy_->GetInputDim().at(0);

  if (shard_batch_and_axis_) {
    slice_size_ = input_shape[0] / params_strategy[0];
    bias_ = rank % params_strategy[0] * slice_size_;
    MS_LOG(INFO) << name_ << ": Sharding batch and axis, the rank is " << rank << ", slice size is " << slice_size_
                 << ", bias is " << bias_;
    return SUCCESS;
  }

  // axis don't split
  if (params_strategy.at(axis_) == 1) {
    bias_ = 0;
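A worked example of the new bias computation, with values assumed from the test below (parameter first dimension 64, params_strategy[0] = 4, 8 ranks in the stage); since % and * share precedence in C++, the bias is (rank % 4) * 16:

# Worked example of slice_size_ / bias_ under shard_batch_and_axis_ (assumed test values).
input_dim0 = 64                         # first dimension of the gather parameter
param_split = 4                         # params_strategy[0]
slice_size = input_dim0 // param_split  # 16 parameter rows per slice
for rank in range(8):
    bias = rank % param_split * slice_size
    print(rank, bias)  # ranks 0..3 -> 0, 16, 32, 48; ranks 4..7 repeat the same biases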
@@ -598,6 +669,11 @@ Status GatherPInfo::InferGroup() {
    dim = dim + 1;
  }

  if (shard_batch_and_axis_) {
    dim = 1;
    MS_LOG(INFO) << name_ << ": Sharding batch and axis, the group dim is " << dim;
  }

  if (dev_matrix.GetDevicesAlongDim(SizeToUlong(dim), &group_devices) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Create group failed.";
    return FAILED;
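Assuming GetDevicesAlongDim indexes the device matrix from the left, dim = 1 selects the dimension of size param_strategy[0], i.e. the ranks that hold different slices of the gather axis for the same indices-batch slice, which are exactly the ranks that need the forward allreduce. A rough sketch of that grouping for the [2, 4] device matrix (my own helpers, not MindSpore API):

# Illustrative grouping: ranks that agree on every device-matrix coordinate except `dim`
# form one communication group (row-major device layout assumed).
def coords_of(rank, dev_matrix):
    out = []
    for size in reversed(dev_matrix):
        out.append(rank % size)
        rank //= size
    return list(reversed(out))

def devices_along_dim(dev_matrix, dim, rank):
    me = coords_of(rank, dev_matrix)
    total = 1
    for size in dev_matrix:
        total *= size
    return [r for r in range(total)
            if all(c == m for i, (c, m) in enumerate(zip(coords_of(r, dev_matrix), me)) if i != dim)]

print(devices_along_dim([2, 4], 1, 0))  # [0, 1, 2, 3]: same indices slice, different param slices
print(devices_along_dim([2, 4], 1, 7))  # [4, 5, 6, 7]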
@@ -70,6 +70,7 @@ class GatherPInfo : public OperatorInfo {
  Status InferBias();
  Status InferOffset();
  Status InferGroup();
  bool ShardBatchAndAxis(const Strategys &strategy);

  int64_t axis_;
  std::string target_ = DEVICE;
@@ -82,6 +83,7 @@ class GatherPInfo : public OperatorInfo {
  bool manual_split_ = false;
  bool dynamic_shape_indices_ = false;
  bool axis_split_forward_allreduce_ = false;  // when axis is split, use reducescatter as default in forward
  bool shard_batch_and_axis_ = false;
  std::vector<int64_t> param_split_shapes_;
  std::vector<int64_t> index_offsets_;
};
@@ -190,6 +190,19 @@ def test_gatherv2_forward_all_reduce():
    _executor.compile(net, x, y)


def test_gatherv2_shard_batch_and_axis():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    strategy1 = ((4, 1), (2, 1))
    strategy2 = ((2, 4, 1), (2, 4, 1))
    net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2, shape=[2, 64])))
    net.set_auto_parallel()

    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    y = Tensor(np.ones([2, 64, 64]), dtype=ms.float32)
    net.set_train()
    _executor.compile(net, x, y)


def test_gatherv2_split_axis_0_repeat_calc():
    context.set_auto_parallel_context(device_num=8, global_rank=7, parallel_mode="semi_auto_parallel")
    strategy1 = ((4, 1), (1, 1))