forked from mindspore-Ecosystem/mindspore
!10743 enable gradients mean in opt shard
From: @gong_zi_yan Reviewed-by: @stsuteng,@yao_yf,@kisnwang Signed-off-by: @stsuteng
This commit is contained in:
commit
b07dd76246
|
@ -1390,9 +1390,16 @@ static void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodeP
|
|||
InsertNode(op, cnode, res.second, parameter, graph, PARALLEL_OPTIMIZER_ALLGATHER);
|
||||
allgather = cnode->input(res.second)->cast<CNodePtr>();
|
||||
}
|
||||
// add fusion flag
|
||||
MS_EXCEPTION_IF_NULL(allgather);
|
||||
// add fusion flag
|
||||
AddCommOpFusionType(allgather, parameter);
|
||||
// add gradients mean
|
||||
auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
|
||||
auto attrs = prim->attrs();
|
||||
MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
|
||||
bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
|
||||
attrs["mean_flag"] = MakeValue<bool>(mean_flag);
|
||||
prim->SetAttrs(attrs);
|
||||
}
|
||||
|
||||
static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr ¶meter,
|
||||
|
|
|
@ -134,6 +134,8 @@ def get_bprop_all_gather(self):
|
|||
rank = get_rank(self.group)
|
||||
dev_num = get_group_size(self.group)
|
||||
split = P.Split(output_num=dev_num)
|
||||
mean_flag = self.get_attr_dict()["mean_flag"]
|
||||
scale = 1/self.rank_size
|
||||
|
||||
def bprop(x, out, dout):
|
||||
if fusion == 0:
|
||||
|
@ -141,6 +143,8 @@ def get_bprop_all_gather(self):
|
|||
else:
|
||||
grad = all_reduce(dout)
|
||||
dx = split(grad)[rank]
|
||||
if mean_flag:
|
||||
dx = F.tensor_mul(dx, scale)
|
||||
return (dx,)
|
||||
|
||||
return bprop
|
||||
|
|
Loading…
Reference in New Issue