!9173 support batch size smaller than dev size in parallel gatherv2

From: @yangzhenzhang
Reviewed-by: @kisnwang, @stsuteng
Signed-off-by: @stsuteng
mindspore-ci-bot 2020-11-30 10:09:05 +08:00 committed by Gitee
commit 42f32b7c4b
3 changed files with 21 additions and 15 deletions


@@ -243,8 +243,8 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
   // axis=0, index_shape(0)%param_strategy(0) must be 0
   Shape index_shape = inputs_shape_.at(1);
   if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0) && !dynamic_shape_indices_) {
-    MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by param_strategy(0).";
-    return FAILED;
+    MS_LOG(INFO) << name_ << ": index_shape(0) can't be divided by param_strategy(0), use allreduce in forward";
+    axis_split_forward_allreduce_ = true;
   }

   if (manual_split_) {
@@ -257,7 +257,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
   // axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0
   if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(LongToSize(axis_))) != 0) {
-    MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
+    MS_LOG(DEBUG) << name_ << ": param_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
     return FAILED;
   }
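
Note: for the axis != 0 case the constraint itself is unchanged; only the log message is corrected to name param_shape(0) instead of index_shape(0). A hedged arithmetic illustration of that check (the variable names below are mine, not from the code):

# axis != 0: the parameter's first dimension must be divisible by the product
# of the shards on dim 0 and the shards on the gather axis.
param_shape = [64, 32]
param_strategy = [2, 4]   # dim 0 split 2 ways, dim 1 (the gather axis) split 4 ways
axis = 1
assert param_shape[0] % (param_strategy[0] * param_strategy[axis]) == 0   # 64 % 8 == 0 -> accepted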
@@ -403,7 +403,8 @@ void GatherV2PInfo::InferOutputsTensorMap() {
   } else {
     // param_strategy(axis) != 1
     if (axis_ == 0) {
-      if (dynamic_shape_indices_ && target_ != CPU) {
+      if ((dynamic_shape_indices_ && target_ != CPU) || axis_split_forward_allreduce_) {
+        // the output is repeat calculation
         tensor_map_out.insert(tensor_map_out.end(), MAP_NONE);
       } else {
         tensor_map_out.insert(tensor_map_out.end(), 0);
@@ -549,15 +550,6 @@ Status GatherV2PInfo::InferGroup() {
   return SUCCESS;
 }

-RankList GetRankFromGroup(const Group &group) {
-  RankList rank_list;
-  auto device_list = group.GetDevicesList();
-  for (auto &device : device_list) {
-    rank_list.insert(rank_list.end(), device.rank() % 8);
-  }
-  return rank_list;
-}
-
 Status GatherV2PInfo::InferForwardCommunication() {
   if (manual_split_) {
     return SUCCESS;
@@ -628,7 +620,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
   auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype});
   auto expand_dims = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), cast, CreatInt64Imm(axis_ - 1)});
   auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, expand_dims});
-  // don't need expandim,if param_size = 1,
+  // don't need expand dim, if param_size = 1
   if (inputs_shape_.at(0).size() == 1) {
     mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, cast});
   }
@@ -640,7 +632,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
   Attr attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
   OperatorAttrs attrs = {attr_op, attr_group};
   AnfNodePtr reduce_op;
-  if (dynamic_shape_indices_) {
+  if (dynamic_shape_indices_ || axis_split_forward_allreduce_) {
     reduce_op = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs), mul});
   } else {
     reduce_op = gen_g.PushBack({gen_g.NewOpInst(REDUCE_SCATTER, attrs), mul});
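
Note: the key behavioral change in the GatherV2PInfo hunks above is the forward communication choice. Previously, splitting the gather axis required index_shape(0) to be divisible by param_strategy(0) so the partial results could be combined with a ReduceScatter; now, when the index (batch) dimension is smaller than, or not divisible by, the number of shards, the partial results are summed with an AllReduce and every device keeps the full (repeated) output. The sketch below is a minimal NumPy simulation of that forward path; it does not use MindSpore APIs, and the helper, its name, and its masking scheme are only an illustration of the EQUAL/CAST/MUL/ALL_REDUCE pattern in ComputeReplaceGraph.

import numpy as np

def sharded_gather_allreduce(param, indices, shard_num):
    # Simulate GatherV2 along axis 0 with the parameter split over `shard_num` devices:
    # every device gathers from its local slice, zeroes rows it does not own,
    # and the partial results are summed (standing in for the forward AllReduce).
    rows_per_shard = param.shape[0] // shard_num
    partial_outputs = []
    for rank in range(shard_num):
        offset = rank * rows_per_shard
        local_slice = param[offset:offset + rows_per_shard]
        local_idx = indices - offset                       # shift indices into this shard's range
        in_range = (local_idx >= 0) & (local_idx < rows_per_shard)
        clipped = np.clip(local_idx, 0, rows_per_shard - 1)
        gathered = local_slice[clipped]                    # local GatherV2 on the shard
        gathered = gathered * in_range[..., np.newaxis]    # mask rows owned by other shards
        partial_outputs.append(gathered)
    return np.sum(partial_outputs, axis=0)                 # AllReduce: every device keeps the full output

# Batch (index dim 0) of 2 is smaller than the 8 shards, so a ReduceScatter could not split
# the output evenly; the summed, AllReduce-style result still matches a plain gather.
param = np.arange(64 * 4, dtype=np.float32).reshape(64, 4)
indices = np.array([[3, 40], [7, 63]])
assert np.array_equal(sharded_gather_allreduce(param, indices, shard_num=8), param[indices])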


@@ -80,6 +80,7 @@ class GatherV2PInfo : public OperatorInfo {
   Group group_;
   bool manual_split_ = false;
   bool dynamic_shape_indices_ = false;
+  bool axis_split_forward_allreduce_ = false;  // when axis is split, use reducescatter as default in forward
   std::vector<int64_t> param_split_shapes_;
   std::vector<int64_t> index_offsets_;
 };


@@ -177,6 +177,19 @@ def test_gatherv2_semi_auto8():
     _executor.compile(net, x, y)


+def test_gatherv2_forward_all_reduce():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((8, 1), (1, 1))
+    strategy2 = ((2, 4, 1), (2, 4, 1))
+    net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2, shape=[2, 64])))
+    net.set_auto_parallel()
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([2, 64, 64]), dtype=ms.float32)
+    net.set_train()
+    _executor.compile(net, x, y)
+
+
 def test_gatherv2_auto0():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
     net = GradWrap(NetWithLoss(Net(0)))
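
Note on why the new test takes the AllReduce path: with device_num=8, strategy1 = ((8, 1), (1, 1)) splits the parameter's first dimension across all 8 devices, while the indices built inside Net (presumably from shape=[2, 64]) have a first dimension of only 2. Since 2 % 8 != 0, CheckStrategy now sets axis_split_forward_allreduce_ instead of failing, InferOutputsTensorMap marks the output as repeated, and ComputeReplaceGraph emits an AllReduce rather than a ReduceScatter. A hedged one-function restatement of that decision (simplified: it ignores the CPU-target special case, and the function name is mine, not MindSpore's):

def forward_comm_for_axis0_gather(index_dim0, param_dim0_shards, dynamic_shape_indices=False):
    # Mirrors the combined condition `dynamic_shape_indices_ || axis_split_forward_allreduce_`.
    if dynamic_shape_indices or index_dim0 % param_dim0_shards != 0:
        return "AllReduce"       # output repeated on every device
    return "ReduceScatter"       # output evenly split along the index dimension

assert forward_comm_for_axis0_gather(2, 8) == "AllReduce"        # this test: batch 2 < 8 devices
assert forward_comm_for_axis0_gather(64, 8) == "ReduceScatter"   # divisible batch keeps the old path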