forked from mindspore-Ecosystem/mindspore
!9173 support batch size smaller than dev size in parallel gatherv2
From: @yangzhenzhang Reviewed-by: @kisnwang, @stsuteng Signed-off-by: @stsuteng
commit 42f32b7c4b
@@ -243,8 +243,8 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
   // axis=0, index_shape(0)%param_strategy(0) must be 0
   Shape index_shape = inputs_shape_.at(1);
   if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0) && !dynamic_shape_indices_) {
-    MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by param_strategy(0).";
-    return FAILED;
+    MS_LOG(INFO) << name_ << ": index_shape(0) can't be divided by param_strategy(0), use allreduce in forward";
+    axis_split_forward_allreduce_ = true;
   }
 
   if (manual_split_) {
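In effect, CheckStrategy no longer rejects a strategy whose index batch dimension is not divisible by the split of the parameter's first dimension; it records the situation and lets the forward pass fall back to AllReduce. A minimal sketch of that decision in plain Python (function and variable names are illustrative, not MindSpore APIs):

```python
def check_axis0_split(index_shape, param_strategy, dynamic_shape_indices=False):
    """Mirror of the relaxed check: flag an AllReduce fallback instead of
    failing when index_shape[0] is not divisible by param_strategy[0]."""
    axis_split_forward_allreduce = False
    if index_shape[0] % param_strategy[0] != 0 and not dynamic_shape_indices:
        # e.g. a batch of 2 indices gathered on 8 devices: 2 % 8 != 0
        axis_split_forward_allreduce = True
    return axis_split_forward_allreduce


assert check_axis0_split([2, 64], (8, 1)) is True     # batch smaller than device number
assert check_axis0_split([64, 64], (8, 1)) is False   # evenly divisible, keep the old path
```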
@@ -257,7 +257,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
 
   // axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0
   if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(LongToSize(axis_))) != 0) {
-    MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
+    MS_LOG(DEBUG) << name_ << ": param_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
     return FAILED;
   }
 
@@ -403,7 +403,8 @@ void GatherV2PInfo::InferOutputsTensorMap() {
   } else {
     // param_strategy(axis) != 1
     if (axis_ == 0) {
-      if (dynamic_shape_indices_ && target_ != CPU) {
+      if ((dynamic_shape_indices_ && target_ != CPU) || axis_split_forward_allreduce_) {
+        // the output is repeat calculation
         tensor_map_out.insert(tensor_map_out.end(), MAP_NONE);
       } else {
         tensor_map_out.insert(tensor_map_out.end(), 0);
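When the forward pass computes the result redundantly and AllReduces it, the output's first dimension can no longer be treated as sharded, so it is mapped to MAP_NONE. A rough illustration of that mapping choice in plain Python (MAP_NONE shown as -1 here purely for the sketch; the helper name is made up):

```python
MAP_NONE = -1  # placeholder for "not split along any device dimension"

def output_dim0_map(dynamic_shape_indices, target, axis_split_forward_allreduce):
    """Choose the tensor-map entry for the output's first dimension when
    axis == 0 and the parameter is split along that axis."""
    if (dynamic_shape_indices and target != "CPU") or axis_split_forward_allreduce:
        return MAP_NONE  # output is computed on every device and reduced, so not sharded
    return 0             # output dim 0 stays mapped to device dimension 0

print(output_dim0_map(False, "Ascend", True))   # -1: AllReduce fallback
print(output_dim0_map(False, "Ascend", False))  # 0: ReduceScatter path keeps the split
```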
@@ -549,15 +550,6 @@ Status GatherV2PInfo::InferGroup() {
   return SUCCESS;
 }
 
-RankList GetRankFromGroup(const Group &group) {
-  RankList rank_list;
-  auto device_list = group.GetDevicesList();
-  for (auto &device : device_list) {
-    rank_list.insert(rank_list.end(), device.rank() % 8);
-  }
-  return rank_list;
-}
-
 Status GatherV2PInfo::InferForwardCommunication() {
   if (manual_split_) {
     return SUCCESS;
@@ -628,7 +620,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
   auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype});
   auto expand_dims = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), cast, CreatInt64Imm(axis_ - 1)});
   auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, expand_dims});
-  // don't need expandim,if param_size = 1,
+  // don't need expand dim, if param_size = 1
   if (inputs_shape_.at(0).size() == 1) {
     mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, cast});
   }
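The Mul applies the cast Equal mask to the gather output; the ExpandDims appears to exist only so that mask broadcasts across the trailing dimensions of a multi-dimensional parameter, which is why it is skipped for a 1-D parameter. A rough NumPy illustration of the shapes involved (not the replace-graph code itself):

```python
import numpy as np

# 2-D parameter: the gather output is (index_len, cols) while the 0/1 mask is
# (index_len,), so the mask needs an extra axis before the elementwise multiply.
gather_out_2d = np.ones((4, 8))
mask = np.array([1.0, 0.0, 1.0, 1.0])
masked_2d = gather_out_2d * mask[:, np.newaxis]   # ExpandDims, then Mul

# 1-D parameter: output and mask already have the same rank, multiply directly.
gather_out_1d = np.ones(4)
masked_1d = gather_out_1d * mask                  # no ExpandDims needed

print(masked_2d.shape, masked_1d.shape)           # (4, 8) (4,)
```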
@@ -640,7 +632,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
   Attr attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
   OperatorAttrs attrs = {attr_op, attr_group};
   AnfNodePtr reduce_op;
-  if (dynamic_shape_indices_) {
+  if (dynamic_shape_indices_ || axis_split_forward_allreduce_) {
     reduce_op = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs), mul});
   } else {
     reduce_op = gen_g.PushBack({gen_g.NewOpInst(REDUCE_SCATTER, attrs), mul});
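ReduceScatter both sums and splits its result, so every rank must receive output_dim0 / group_size rows; when the batch dimension is smaller than the group (or otherwise not divisible), that split does not exist, and the commit switches to AllReduce, which sums but leaves the full result on every rank. A hedged NumPy simulation of the two collectives' output shapes (simplified, not MindSpore code):

```python
import numpy as np

def all_reduce(chunks):
    """Every rank ends up with the full elementwise sum."""
    total = np.sum(chunks, axis=0)
    return [total.copy() for _ in chunks]

def reduce_scatter(chunks):
    """Every rank ends up with an equal slice of the sum along dim 0."""
    total = np.sum(chunks, axis=0)
    group = len(chunks)
    assert total.shape[0] % group == 0, "dim 0 not divisible by group size"
    return list(np.split(total, group, axis=0))

group_size = 8
partials = [np.ones((2, 64)) for _ in range(group_size)]   # batch 2 on 8 devices

print(all_reduce(partials)[0].shape)   # (2, 64): valid for any batch size
# reduce_scatter(partials) would assert here: 2 % 8 != 0, hence the AllReduce fallback
```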
@@ -80,6 +80,7 @@ class GatherV2PInfo : public OperatorInfo {
   Group group_;
   bool manual_split_ = false;
   bool dynamic_shape_indices_ = false;
+  bool axis_split_forward_allreduce_ = false;  // when axis is split, use reducescatter as default in forward
   std::vector<int64_t> param_split_shapes_;
   std::vector<int64_t> index_offsets_;
 };
@@ -177,6 +177,19 @@ def test_gatherv2_semi_auto8():
     _executor.compile(net, x, y)
 
 
+def test_gatherv2_forward_all_reduce():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((8, 1), (1, 1))
+    strategy2 = ((2, 4, 1), (2, 4, 1))
+    net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2, shape=[2, 64])))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([2, 64, 64]), dtype=ms.float32)
+    net.set_train()
+    _executor.compile(net, x, y)
+
+
 def test_gatherv2_auto0():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
     net = GradWrap(NetWithLoss(Net(0)))
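Reading the new test: `shape=[2, 64]` appears to be the index shape handed to GatherV2 (the [2, 64, 64] label matches index_shape + param_shape[1:]), while strategy1 splits the [64, 64] parameter's first dimension 8 ways, so the strategy would previously have been rejected and now takes the forward AllReduce path. A quick check of the arithmetic the test relies on (plain Python, names chosen for the sketch):

```python
index_shape = [2, 64]     # the shape passed to Net(...) in the new test
param_strategy = (8, 1)   # first tuple of strategy1

# The condition the commit relaxes: previously this made CheckStrategy fail,
# now it only selects the forward AllReduce path.
assert index_shape[0] % param_strategy[0] != 0   # 2 % 8 != 0
```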