forked from mindspore-Ecosystem/mindspore
!9396 enable allgather fusion
From: @gong_zi_yan Reviewed-by: @stsuteng,@yangzhenzhang,@kisnwang Signed-off-by: @stsuteng,@kisnwang
Commit 6b9e402790
@@ -1408,7 +1408,7 @@ void InsertAllGatherOp(const std::string &group, const std::pair<AnfNodePtr, int
   auto prim = GetValueNode<PrimitivePtr>(allgather->input(0));
   auto attrs = prim->attrs();
   // enable fusion flag later when it's supported in backend
-  attrs["fusion"] = MakeValue<int64_t>(0);
+  attrs["fusion"] = MakeValue<int64_t>(1);
   prim->SetAttrs(attrs);
 }
 
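The frontend now tags every AllGather it inserts with fusion id 1, so the backend can group these collectives into a single fused launch instead of issuing them one by one. A minimal sketch of what that attribute looks like on the Python side, assuming communication has been initialized (the op is constructed standalone here purely for illustration):

from mindspore.ops.operations import AllGather

# Communication ops that carry the same non-zero "fusion" id become
# candidates for being fused into one collective call by the backend;
# fusion id 0 means "do not fuse".
all_gather = AllGather()               # default communication group
all_gather.add_prim_attr("fusion", 1)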
@@ -16,6 +16,7 @@
 """Generate bprop for comm ops"""
 import mindspore.common.dtype as mstype
 from mindspore.ops import functional as F
+from mindspore.communication import get_rank, get_group_size
 from .. import operations as P
 from ...common.tensor import RowTensor
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
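The added import brings in the topology helpers that the fused branch of the bprop below relies on. A hedged usage sketch, assuming a distributed launch (the actual rank and group size depend on how the job is started):

from mindspore.communication import init, get_rank, get_group_size

init()                      # requires a distributed launch (e.g. mpirun / rank table)
rank = get_rank()           # this process's rank within the world group
dev_num = get_group_size()  # number of devices in the world group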
@@ -117,15 +118,27 @@ def get_bprop_broad_cast(self):
 @bprop_getters.register(AllGather)
 def get_bprop_all_gather(self):
     """Generate bprop for AllGather"""
-    all_gather_grad = ReduceScatter(ReduceOp.SUM, self.group)
     fusion = self.get_attr_dict()["fusion"]
-    all_gather_grad.add_prim_attr("fusion", fusion)
-    if self.instance_name:
-        instance_name = "grad_" + self.instance_name
-        all_gather_grad.set_prim_instance_name(instance_name)
+    if fusion == 0:
+        reduce_scatter = ReduceScatter(ReduceOp.SUM, self.group)
+        if self.instance_name:
+            instance_name = "grad_" + self.instance_name
+            reduce_scatter.set_prim_instance_name(instance_name)
+    else:
+        all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", 1)
+        if self.instance_name:
+            instance_name = "grad_" + self.instance_name
+            all_reduce.set_prim_instance_name(instance_name)
+        rank = get_rank(self.group)
+        dev_num = get_group_size(self.group)
+        split = P.Split(output_num=dev_num)
 
     def bprop(x, out, dout):
-        dx = all_gather_grad(dout)
+        if fusion == 0:
+            dx = reduce_scatter(dout)
+        else:
+            grad = all_reduce(dout)
+            dx = split(grad)[rank]
         return (dx,)
 
     return bprop
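The two-path rewrite preserves the original gradient: the bprop of AllGather is ReduceScatter(SUM), and on each rank ReduceScatter(SUM) yields exactly the rank-th shard of AllReduce(SUM) over the same inputs. The fused branch exploits this by running a fusable AllReduce and slicing out the local shard with Split, presumably because AllReduce is the collective the backend fusion pass handles. A quick NumPy check of that identity (sizes are arbitrary):

import numpy as np

# Identity behind the fused path: on rank r,
#   ReduceScatter_SUM(douts)[r] == Split(AllReduce_SUM(douts), dev_num)[r]
dev_num, shard = 4, 3
douts = [np.random.rand(dev_num * shard) for _ in range(dev_num)]  # one dout per rank

all_reduce = sum(douts)  # what AllReduce(ReduceOp.SUM) would leave on every rank
for rank in range(dev_num):
    # what ReduceScatter(ReduceOp.SUM) would leave on this rank
    reduce_scatter = sum(d[rank * shard:(rank + 1) * shard] for d in douts)
    assert np.allclose(reduce_scatter, np.split(all_reduce, dev_num)[rank])
print("ReduceScatter(SUM) matches AllReduce(SUM) + local shard")

Setting fusion id 1 on the backward AllReduce mirrors the forward-pass change, so the forward and backward collectives can both be grouped by the backend.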