forked from mindspore-Ecosystem/mindspore
add allreduce prod
This commit is contained in:
parent
b4773004fb
commit
1529d544b9
|
@ -37,11 +37,18 @@ def get_bprop_all_reduce(self):
|
|||
equal = P.Equal()
|
||||
cast = P.Cast()
|
||||
mul = P.Mul()
|
||||
div = P.RealDiv()
|
||||
dtype = P.DType()
|
||||
|
||||
if self.op == ReduceOp.PROD:
|
||||
raise RuntimeError("The bprop of ReduceOp.PROD is not supported yet.")
|
||||
if self.op == ReduceOp.SUM:
|
||||
|
||||
def bprop(x, out, dout):
|
||||
dy1 = mul(dout, out)
|
||||
dy2 = all_reduce_grad(dy1)
|
||||
dx = div(dy2, x)
|
||||
return (dx,)
|
||||
|
||||
elif self.op == ReduceOp.SUM:
|
||||
|
||||
def bprop(x, out, dout):
|
||||
if F.issubclass_(F.typeof(dout), mstype.tensor):
|
||||
|
|
|
@ -92,8 +92,6 @@ class AllReduce(PrimitiveWithInfer):
|
|||
def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
|
||||
if not isinstance(op, type(ReduceOp.SUM)):
|
||||
raise TypeError("The operation of AllReduce should be str.")
|
||||
if op == ReduceOp.PROD:
|
||||
raise RuntimeError("The operation of AllReduce 'prod' is not supported yet.")
|
||||
if not isinstance(_get_group(group), str):
|
||||
raise TypeError("The group of AllReduce should be str.")
|
||||
self.op = op
|
||||
|
|
|
@ -138,6 +138,7 @@ def test_allreduce():
|
|||
run_allreduce(ReduceOp.SUM)
|
||||
run_allreduce(ReduceOp.MAX)
|
||||
run_allreduce(ReduceOp.MIN)
|
||||
run_allreduce(ReduceOp.PROD)
|
||||
|
||||
|
||||
def test_allgather():
|
||||
|
|
Loading…
Reference in New Issue