diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc
index 444134b1771..e50e6c659b9 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc
@@ -243,8 +243,8 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
   // axis=0, index_shape(0)%param_strategy(0) must be 0
   Shape index_shape = inputs_shape_.at(1);
   if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0) && !dynamic_shape_indices_) {
-    MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by param_strategy(0).";
-    return FAILED;
+    MS_LOG(INFO) << name_ << ": index_shape(0) can't be divided by param_strategy(0), use allreduce in forward";
+    axis_split_forward_allreduce_ = true;
   }
 
   if (manual_split_) {
@@ -257,7 +257,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
 
   // axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0
   if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(LongToSize(axis_))) != 0) {
-    MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
+    MS_LOG(DEBUG) << name_ << ": param_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
     return FAILED;
   }
 
@@ -403,7 +403,8 @@ void GatherV2PInfo::InferOutputsTensorMap() {
   } else {
     // param_strategy(axis) != 1
     if (axis_ == 0) {
-      if (dynamic_shape_indices_ && target_ != CPU) {
+      if ((dynamic_shape_indices_ && target_ != CPU) || axis_split_forward_allreduce_) {
+        // the output is repeat calculation
         tensor_map_out.insert(tensor_map_out.end(), MAP_NONE);
       } else {
         tensor_map_out.insert(tensor_map_out.end(), 0);
@@ -549,15 +550,6 @@ Status GatherV2PInfo::InferGroup() {
   return SUCCESS;
 }
-RankList GetRankFromGroup(const Group &group) {
-  RankList rank_list;
-  auto device_list = group.GetDevicesList();
-  for (auto &device : device_list) {
-    rank_list.insert(rank_list.end(), device.rank() % 8);
-  }
-  return rank_list;
-}
-
 Status GatherV2PInfo::InferForwardCommunication() {
   if (manual_split_) {
     return SUCCESS;
   }
@@ -628,7 +620,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
   auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype});
   auto expand_dims = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), cast, CreatInt64Imm(axis_ - 1)});
   auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, expand_dims});
-  // don't need expandim,if param_size = 1,
+  // don't need expand dim, if param_size = 1
   if (inputs_shape_.at(0).size() == 1) {
     mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, cast});
   }
@@ -640,7 +632,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
   Attr attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
   OperatorAttrs attrs = {attr_op, attr_group};
   AnfNodePtr reduce_op;
-  if (dynamic_shape_indices_) {
+  if (dynamic_shape_indices_ || axis_split_forward_allreduce_) {
     reduce_op = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs), mul});
   } else {
     reduce_op = gen_g.PushBack({gen_g.NewOpInst(REDUCE_SCATTER, attrs), mul});
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h
index 1f09170df5a..39fbe446d24 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.h
@@ -80,6 +80,7 @@ class GatherV2PInfo : public OperatorInfo {
   Group group_;
   bool manual_split_ = false;
   bool dynamic_shape_indices_ = false;
+  bool axis_split_forward_allreduce_ = false;  // when axis is split, use reducescatter as default in forward
   std::vector param_split_shapes_;
   std::vector index_offsets_;
 };
diff --git a/tests/ut/python/parallel/test_gather_v2.py b/tests/ut/python/parallel/test_gather_v2.py
index b950d8b43a5..15090365091 100644
--- a/tests/ut/python/parallel/test_gather_v2.py
+++ b/tests/ut/python/parallel/test_gather_v2.py
@@ -177,6 +177,19 @@ def test_gatherv2_semi_auto8():
     _executor.compile(net, x, y)
 
 
+def test_gatherv2_forward_all_reduce():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((8, 1), (1, 1))
+    strategy2 = ((2, 4, 1), (2, 4, 1))
+    net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2, shape=[2, 64])))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([2, 64, 64]), dtype=ms.float32)
+    net.set_train()
+    _executor.compile(net, x, y)
+
+
 def test_gatherv2_auto0():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
     net = GradWrap(NetWithLoss(Net(0)))