forked from mindspore-Ecosystem/mindspore
!9173 Support batch size smaller than device size in parallel GatherV2
From: @yangzhenzhang  Reviewed-by: @kisnwang, @stsuteng  Signed-off-by: @stsuteng
Commit: 42f32b7c4b
@@ -243,8 +243,8 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
   // axis=0, index_shape(0)%param_strategy(0) must be 0
   Shape index_shape = inputs_shape_.at(1);
   if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0) && !dynamic_shape_indices_) {
-    MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by param_strategy(0).";
-    return FAILED;
+    MS_LOG(INFO) << name_ << ": index_shape(0) can't be divided by param_strategy(0), use allreduce in forward";
+    axis_split_forward_allreduce_ = true;
   }

   if (manual_split_) {
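The hunk above no longer rejects a strategy whose index batch does not divide the axis-0 shard count; it just records that the forward pass must fall back to AllReduce. A rough Python restatement of that decision (illustrative only, not MindSpore code; the function name is made up):

def forward_comm_for_axis0(index_batch, param_shard_axis0, dynamic_shape_indices=False):
    # dynamic-shape indices already took the AllReduce path before this change
    if dynamic_shape_indices:
        return "ALL_REDUCE"
    # new behaviour: an uneven split no longer fails CheckStrategy,
    # it only switches the forward communication (axis_split_forward_allreduce_ = true)
    if index_batch % param_shard_axis0 != 0:
        return "ALL_REDUCE"
    return "REDUCE_SCATTER"

# batch size 2 on an 8-way split: previously FAILED, now compiles with AllReduce in the forward
assert forward_comm_for_axis0(2, 8) == "ALL_REDUCE"
assert forward_comm_for_axis0(64, 8) == "REDUCE_SCATTER"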
@@ -257,7 +257,7 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {

   // axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0
   if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(LongToSize(axis_))) != 0) {
-    MS_LOG(DEBUG) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
+    MS_LOG(DEBUG) << name_ << ": param_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";
     return FAILED;
   }

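For the axis != 0 case the constraint is plain shard arithmetic; a quick worked check with made-up numbers (not taken from the source):

# axis != 0: dim 0 of the parameter must be divisible by the product of the
# shard counts on dim 0 and on the gathered axis.
param_shape = (64, 32)
param_strategy = (2, 4)   # 2-way split of dim 0, 4-way split of dim 1 (the gather axis)
axis = 1
assert param_shape[0] % (param_strategy[0] * param_strategy[axis]) == 0   # 64 % 8 == 0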
@@ -403,7 +403,8 @@ void GatherV2PInfo::InferOutputsTensorMap() {
   } else {
     // param_strategy(axis) != 1
     if (axis_ == 0) {
-      if (dynamic_shape_indices_ && target_ != CPU) {
+      if ((dynamic_shape_indices_ && target_ != CPU) || axis_split_forward_allreduce_) {
+        // the output is repeat calculation
         tensor_map_out.insert(tensor_map_out.end(), MAP_NONE);
       } else {
         tensor_map_out.insert(tensor_map_out.end(), 0);
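In the tensor map above, MAP_NONE marks the output's batch dimension as replicated on every device (the AllReduce result is a "repeat calculation"), while a map of 0 leaves it sharded along the device dimension as ReduceScatter would. A small shape sketch of the two cases (illustrative numbers only, not MindSpore API):

param_shape = (64, 64)     # gather table, split 8 ways on dim 0
index_shape = (2, 64)      # batch of 2 is smaller than the 8-way split
shards = 8

full_output = index_shape + param_shape[1:]                     # (2, 64, 64)
if index_shape[0] % shards == 0:
    # ReduceScatter path: dim 0 of the output stays sharded across devices
    per_device_output = (index_shape[0] // shards,) + full_output[1:]
else:
    # AllReduce path: every device keeps the whole output (MAP_NONE on dim 0)
    per_device_output = full_output
print(per_device_output)   # (2, 64, 64)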
@@ -549,15 +550,6 @@ Status GatherV2PInfo::InferGroup() {
   return SUCCESS;
 }

-RankList GetRankFromGroup(const Group &group) {
-  RankList rank_list;
-  auto device_list = group.GetDevicesList();
-  for (auto &device : device_list) {
-    rank_list.insert(rank_list.end(), device.rank() % 8);
-  }
-  return rank_list;
-}
-
 Status GatherV2PInfo::InferForwardCommunication() {
   if (manual_split_) {
     return SUCCESS;
@@ -628,7 +620,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
   auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype});
   auto expand_dims = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), cast, CreatInt64Imm(axis_ - 1)});
   auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, expand_dims});
-  // don't need expandim,if param_size = 1,
+  // don't need expand dim, if param_size = 1
   if (inputs_shape_.at(0).size() == 1) {
     mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, cast});
   }
@@ -640,7 +632,7 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
   Attr attr_group = std::make_pair(GROUP, MakeValue(group_.name()));
   OperatorAttrs attrs = {attr_op, attr_group};
   AnfNodePtr reduce_op;
-  if (dynamic_shape_indices_) {
+  if (dynamic_shape_indices_ || axis_split_forward_allreduce_) {
     reduce_op = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs), mul});
   } else {
     reduce_op = gen_g.PushBack({gen_g.NewOpInst(REDUCE_SCATTER, attrs), mul});
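The replace graph masks each device's partial gather and then combines the pieces with the communication op chosen above. A minimal NumPy simulation of the AllReduce path, assuming an even 8-way split of the table on dim 0 (an approximation of the graph, not the actual MindSpore ops):

import numpy as np

np.random.seed(0)
devices = 8
param = np.random.randn(64, 32).astype(np.float32)   # full gather table
indices = np.array([5, 60])                           # batch of 2 < 8 devices
expected = param[indices]                             # unsharded GatherV2 result

rows_per_dev = param.shape[0] // devices
partials = []
for rank in range(devices):
    offset = rank * rows_per_dev
    local = param[offset:offset + rows_per_dev]       # this device's slice of the table
    rel = indices - offset                            # shift indices into the local range
    in_range = (rel >= 0) & (rel < rows_per_dev)      # mask of indices this shard owns
    gathered = local[np.clip(rel, 0, rows_per_dev - 1)]
    partials.append(gathered * in_range[:, None])     # MUL: zero out rows owned by other shards

# ALL_REDUCE: summing the masked partial gathers reproduces the full output on every device,
# which is why an index batch that cannot be reduce-scattered evenly can still be supported.
assert np.allclose(np.sum(partials, axis=0), expected)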
@@ -80,6 +80,7 @@ class GatherV2PInfo : public OperatorInfo {
   Group group_;
   bool manual_split_ = false;
   bool dynamic_shape_indices_ = false;
+  bool axis_split_forward_allreduce_ = false;  // when axis is split, use reducescatter as default in forward
   std::vector<int64_t> param_split_shapes_;
   std::vector<int64_t> index_offsets_;
 };
@@ -177,6 +177,19 @@ def test_gatherv2_semi_auto8():
     _executor.compile(net, x, y)


+def test_gatherv2_forward_all_reduce():
+    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
+    strategy1 = ((8, 1), (1, 1))
+    strategy2 = ((2, 4, 1), (2, 4, 1))
+    net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2, shape=[2, 64])))
+    net.set_auto_parallel()
+
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    y = Tensor(np.ones([2, 64, 64]), dtype=ms.float32)
+    net.set_train()
+    _executor.compile(net, x, y)
+
+
 def test_gatherv2_auto0():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel")
     net = GradWrap(NetWithLoss(Net(0)))
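For contrast, a hypothetical companion case (not part of this change) where the index batch of 64 divides the 8-way split of dim 0, so CheckStrategy keeps the default ReduceScatter path. It assumes the same Net/NetWithLoss/GradWrap helpers as the test above and that shape is the indices shape:

def test_gatherv2_forward_reduce_scatter_sketch():
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
    strategy1 = ((8, 1), (1, 1))
    strategy2 = ((2, 4, 1), (2, 4, 1))
    # an index batch of 64 divides the 8-way split of dim 0, so no AllReduce fallback is needed
    net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2, shape=[64, 64])))
    net.set_auto_parallel()

    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
    y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
    net.set_train()
    _executor.compile(net, x, y)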