From e29f5c96cb9013d2f728f3cf0f4f971e9fd1badd Mon Sep 17 00:00:00 2001
From: Ziyan
Date: Wed, 2 Dec 2020 19:31:41 +0800
Subject: [PATCH] enable_allgather_fusion

Enable the fusion flag on the AllGather inserted for parallel training.
Since ReduceScatter fusion is not yet supported in the backend, the
AllGather bprop switches from ReduceScatter to a fused AllReduce
followed by a Split that keeps the local rank's slice, which computes
the same gradient.
---
 .../ccsrc/frontend/parallel/step_parallel.cc |  2 +-
 mindspore/ops/_grad/grad_comm_ops.py         | 25 +++++++++++++++++++------
 2 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index edb62f55218..5d80e3ee489 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -1408,7 +1408,7 @@ void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, in
   auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
   auto attrs = prim->attrs();
   // enable fusion flag later when it's supported in backend
-  attrs["fusion"] = MakeValue(0);
+  attrs["fusion"] = MakeValue(1);
   prim->SetAttrs(attrs);
 }

diff --git a/mindspore/ops/_grad/grad_comm_ops.py b/mindspore/ops/_grad/grad_comm_ops.py
index 684ed98b119..88e8641e4ca 100644
--- a/mindspore/ops/_grad/grad_comm_ops.py
+++ b/mindspore/ops/_grad/grad_comm_ops.py
@@ -16,6 +16,7 @@
 """Generate bprop for comm ops"""
 import mindspore.common.dtype as mstype
 from mindspore.ops import functional as F
+from mindspore.communication import get_rank, get_group_size
 from .. import operations as P
 from ...common.tensor import RowTensor
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
@@ -116,15 +117,27 @@ def get_bprop_broad_cast(self):
 @bprop_getters.register(AllGather)
 def get_bprop_all_gather(self):
     """Generate bprop for AllGather"""
-    all_gather_grad = ReduceScatter(ReduceOp.SUM, self.group)
     fusion = self.get_attr_dict()["fusion"]
-    all_gather_grad.add_prim_attr("fusion", fusion)
-    if self.instance_name:
-        instance_name = "grad_" + self.instance_name
-        all_gather_grad.set_prim_instance_name(instance_name)
+    if fusion == 0:
+        reduce_scatter = ReduceScatter(ReduceOp.SUM, self.group)
+        if self.instance_name:
+            instance_name = "grad_" + self.instance_name
+            reduce_scatter.set_prim_instance_name(instance_name)
+    else:
+        all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", 1)
+        if self.instance_name:
+            instance_name = "grad_" + self.instance_name
+            all_reduce.set_prim_instance_name(instance_name)
+        rank = get_rank(self.group)
+        dev_num = get_group_size(self.group)
+        split = P.Split(output_num=dev_num)

     def bprop(x, out, dout):
-        dx = all_gather_grad(dout)
+        if fusion == 0:
+            dx = reduce_scatter(dout)
+        else:
+            grad = all_reduce(dout)
+            dx = split(grad)[rank]
         return (dx,)

     return bprop
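
Note on the fusion != 0 path (editorial, not part of the patch): the
backward of AllGather is ReduceScatter(SUM), and the patch replaces it
with a fused AllReduce(SUM) followed by a Split that keeps the local
rank's slice. The sketch below simulates both paths in numpy to show
they agree; the rank count, shapes, and numpy stand-ins for the
collective ops are illustrative assumptions, not MindSpore APIs.

    import numpy as np

    dev_num = 4
    # One dout per simulated rank, each covering the full AllGather
    # output (dev_num slices of 2 rows concatenated along axis 0).
    douts = [np.random.randn(dev_num * 2, 3) for _ in range(dev_num)]

    # Reference: ReduceScatter(SUM) sums all ranks' dout elementwise,
    # then each rank keeps only its own slice of the result.
    summed = np.add.reduce(douts)
    reference = np.split(summed, dev_num, axis=0)

    # Fused path from the patch: AllReduce(SUM) leaves the identical
    # summed tensor on every rank; P.Split(output_num=dev_num) plus
    # indexing by rank then drops everything but the local slice.
    for rank in range(dev_num):
        all_reduce_out = np.add.reduce(douts)  # same on every rank
        dx = np.split(all_reduce_out, dev_num, axis=0)[rank]
        assert np.allclose(dx, reference[rank])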
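The trade-off is that the fused path materializes the full-size summed
tensor on every rank and discards (dev_num - 1)/dev_num of it after the
Split, in exchange for a communication op the backend can fuse; once
ReduceScatter fusion lands in the backend, the fusion == 0 branch's
ReduceScatter would be the leaner choice for both paths.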