!10743 enable gradients mean in opt shard

From: @gong_zi_yan
Reviewed-by: @stsuteng, @yao_yf, @kisnwang
Signed-off-by: @stsuteng
Committed by mindspore-ci-bot via Gitee, 2020-12-29 21:36:35 +08:00
commit b07dd76246
2 changed files with 12 additions and 1 deletion


@@ -1390,9 +1390,16 @@ static void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodeP
    InsertNode(op, cnode, res.second, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER);
    allgather = cnode->input(res.second)->cast<CNodePtr>();
  }
  MS_EXCEPTION_IF_NULL(allgather);
  // add fusion flag
  AddCommOpFusionType(allgather, parameter);
  // add gradients mean
  auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
  auto attrs = prim->attrs();
  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
  bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
  attrs["mean_flag"] = MakeValue<bool>(mean_flag);
  prim->SetAttrs(attrs);
}
static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &parameter,
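
The hunk above runs at graph-compile time: when the parallel optimizer inserts an AllGather for a sharded parameter, it now also copies the gradients_mean setting from the ParallelContext onto that primitive as a "mean_flag" attribute, which the Python bprop below reads back. A framework-free Python sketch of that plumbing, using stand-in classes rather than MindSpore's real API:

class FakeParallelContext:
    """Stand-in for MindSpore's ParallelContext singleton (assumed shape)."""
    _instance = None

    def __init__(self):
        self.gradients_mean = True  # normally set via set_auto_parallel_context

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance


class FakeAllGather:
    """Stand-in for the AllGather primitive node the pass inserts."""

    def __init__(self):
        self._attrs = {}

    def add_prim_attr(self, key, value):
        self._attrs[key] = value

    def get_attr_dict(self):
        return self._attrs


def insert_all_gather_op(prim):
    # Mirrors the C++ hunk: read gradients_mean once at compile time and
    # freeze it onto the op, so the bprop need not consult global state.
    prim.add_prim_attr("mean_flag", FakeParallelContext.get_instance().gradients_mean)


ag = FakeAllGather()
insert_all_gather_op(ag)
assert ag.get_attr_dict()["mean_flag"] is True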


@@ -134,6 +134,8 @@ def get_bprop_all_gather(self):
    rank = get_rank(self.group)
    dev_num = get_group_size(self.group)
    split = P.Split(output_num=dev_num)
    mean_flag = self.get_attr_dict()["mean_flag"]
    scale = 1 / self.rank_size

    def bprop(x, out, dout):
        if fusion == 0:
@@ -141,6 +143,8 @@ def get_bprop_all_gather(self):
        else:
            grad = all_reduce(dout)
            dx = split(grad)[rank]
        if mean_flag:
            dx = F.tensor_mul(dx, scale)
        return (dx,)

    return bprop
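
The scaling is what turns the reduced sum back into a mean. A NumPy sanity check (not MindSpore) of the fused else-branch above, with made-up values for dev_num and rank and all_reduce simulated as a plain sum:

import numpy as np

dev_num, rank = 4, 1                    # hypothetical group size and rank
mean_flag = True
scale = 1 / dev_num

dout = np.arange(8, dtype=np.float32)   # gathered gradient, 2 elements per rank
grad = dev_num * dout                   # all_reduce(SUM), assuming every device
                                        # produced this same dout
dx = np.split(grad, dev_num)[rank]      # this rank's shard: [8., 12.]
if mean_flag:
    dx = dx * scale                     # sum -> mean: [2., 3.]

assert np.allclose(dx, dout[2 * rank:2 * (rank + 1)])

Without mean_flag each rank would keep the dev_num-times-larger summed gradient, so setting the flag keeps optimizer-sharded parameters consistent with the gradients_mean behavior applied to the rest of the data-parallel gradients.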