forked from mindspore-Ecosystem/mindspore
!9550 Optimize performance of PyNative grad reduce
From: @jojobugfree Reviewed-by: @kisnwang,@jjfeing Signed-off-by: @jjfeing
This commit is contained in:
commit
e4f1365495
|
@ -65,9 +65,7 @@ def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, gra
|
||||||
grad = allreduce(grad)
|
grad = allreduce(grad)
|
||||||
if mean:
|
if mean:
|
||||||
degree = F.scalar_cast(degree, F.dtype(grad))
|
degree = F.scalar_cast(degree, F.dtype(grad))
|
||||||
cast_op = P.Cast()
|
grad = F.tensor_mul(grad, F.cast(F.scalar_to_array(1.0 / degree), F.dtype(grad)))
|
||||||
mul_op = P.Mul()
|
|
||||||
grad = mul_op(grad, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(grad)))
|
|
||||||
return grad
|
return grad
|
||||||
return grad
|
return grad
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue