From 1529d544b93a9c86dedaf680ed8c3fd86d837272 Mon Sep 17 00:00:00 2001 From: yao_yf Date: Tue, 17 Nov 2020 09:41:57 +0800 Subject: [PATCH] add allreduce prod --- mindspore/ops/_grad/grad_comm_ops.py | 11 +++++++++-- mindspore/ops/operations/comm_ops.py | 2 -- tests/ut/python/communication/test_comm.py | 1 + 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/mindspore/ops/_grad/grad_comm_ops.py b/mindspore/ops/_grad/grad_comm_ops.py index 3f5dc390b4a..48bc0b940ce 100644 --- a/mindspore/ops/_grad/grad_comm_ops.py +++ b/mindspore/ops/_grad/grad_comm_ops.py @@ -37,11 +37,18 @@ def get_bprop_all_reduce(self): equal = P.Equal() cast = P.Cast() mul = P.Mul() + div = P.RealDiv() dtype = P.DType() if self.op == ReduceOp.PROD: - raise RuntimeError("The bprop of ReduceOp.PROD is not supported yet.") - if self.op == ReduceOp.SUM: + + def bprop(x, out, dout): + dy1 = mul(dout, out) + dy2 = all_reduce_grad(dy1) + dx = div(dy2, x) + return (dx,) + + elif self.op == ReduceOp.SUM: def bprop(x, out, dout): if F.issubclass_(F.typeof(dout), mstype.tensor): diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py index 9ce3a69f0ee..19451fc237f 100644 --- a/mindspore/ops/operations/comm_ops.py +++ b/mindspore/ops/operations/comm_ops.py @@ -92,8 +92,6 @@ class AllReduce(PrimitiveWithInfer): def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP): if not isinstance(op, type(ReduceOp.SUM)): raise TypeError("The operation of AllReduce should be str.") - if op == ReduceOp.PROD: - raise RuntimeError("The operation of AllReduce 'prod' is not supported yet.") if not isinstance(_get_group(group), str): raise TypeError("The group of AllReduce should be str.") self.op = op diff --git a/tests/ut/python/communication/test_comm.py b/tests/ut/python/communication/test_comm.py index 7df6149e7b6..4fb5b5c3b0e 100644 --- a/tests/ut/python/communication/test_comm.py +++ b/tests/ut/python/communication/test_comm.py @@ -138,6 +138,7 @@ def test_allreduce(): run_allreduce(ReduceOp.SUM) run_allreduce(ReduceOp.MAX) run_allreduce(ReduceOp.MIN) + run_allreduce(ReduceOp.PROD) def test_allgather():