add allreduce prod

This commit is contained in:
yao_yf 2020-11-17 09:41:57 +08:00
parent b4773004fb
commit 1529d544b9
3 changed files with 10 additions and 4 deletions

View File

@ -37,11 +37,18 @@ def get_bprop_all_reduce(self):
equal = P.Equal()
cast = P.Cast()
mul = P.Mul()
div = P.RealDiv()
dtype = P.DType()
if self.op == ReduceOp.PROD:
raise RuntimeError("The bprop of ReduceOp.PROD is not supported yet.")
if self.op == ReduceOp.SUM:
def bprop(x, out, dout):
dy1 = mul(dout, out)
dy2 = all_reduce_grad(dy1)
dx = div(dy2, x)
return (dx,)
elif self.op == ReduceOp.SUM:
def bprop(x, out, dout):
if F.issubclass_(F.typeof(dout), mstype.tensor):

View File

@ -92,8 +92,6 @@ class AllReduce(PrimitiveWithInfer):
def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP):
if not isinstance(op, type(ReduceOp.SUM)):
raise TypeError("The operation of AllReduce should be str.")
if op == ReduceOp.PROD:
raise RuntimeError("The operation of AllReduce 'prod' is not supported yet.")
if not isinstance(_get_group(group), str):
raise TypeError("The group of AllReduce should be str.")
self.op = op

View File

@ -138,6 +138,7 @@ def test_allreduce():
run_allreduce(ReduceOp.SUM)
run_allreduce(ReduceOp.MAX)
run_allreduce(ReduceOp.MIN)
run_allreduce(ReduceOp.PROD)
def test_allgather():