From 98566ddc07f3ef6ecc8004b1ad3629b2e1e435ae Mon Sep 17 00:00:00 2001
From: Ziyan <ziyan@huawei.com>
Date: Tue, 29 Dec 2020 10:44:14 +0800
Subject: [PATCH] enable gradients mean in opt shard

---
 mindspore/ccsrc/frontend/parallel/step_parallel.cc | 9 ++++++++-
 mindspore/ops/_grad/grad_comm_ops.py               | 4 ++++
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index 3e69ddc74d2..ce97cd37633 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -1390,9 +1390,16 @@ static void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int> &res,
   } else {
     allgather = node->input(res.second)->cast<CNodePtr>();
   }
-  // add fusion flag
   MS_EXCEPTION_IF_NULL(allgather);
+  // add fusion flag
   AddCommOpFusionType(allgather, parameter);
+  // add gradients mean
+  auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
+  auto attrs = prim->attrs();
+  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
+  bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
+  attrs["mean_flag"] = MakeValue(mean_flag);
+  prim->SetAttrs(attrs);
 }
 
 static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr &parameter,
diff --git a/mindspore/ops/_grad/grad_comm_ops.py b/mindspore/ops/_grad/grad_comm_ops.py
index 82a7e0e0f00..17655bf321a 100644
--- a/mindspore/ops/_grad/grad_comm_ops.py
+++ b/mindspore/ops/_grad/grad_comm_ops.py
@@ -134,6 +134,8 @@ def get_bprop_all_gather(self):
     rank = get_rank(self.group)
     dev_num = get_group_size(self.group)
     split = P.Split(output_num=dev_num)
+    mean_flag = self.get_attr_dict()["mean_flag"]
+    scale = 1/self.rank_size
 
     def bprop(x, out, dout):
         if fusion == 0:
@@ -141,6 +143,8 @@ def get_bprop_all_gather(self):
         else:
             grad = all_reduce(dout)
             dx = split(grad)[rank]
+            if mean_flag:
+                dx = F.tensor_mul(dx, scale)
         return (dx,)
 
     return bprop