!45564 Add dynamic-shape support to the ReduceProd bprop on CPU

Merge pull request !45564 from zhangdong/trans_dyn
This commit is contained in:
i-robot 2022-11-17 07:57:52 +00:00 committed by Gitee
commit 666428a765
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed file with 46 additions and 6 deletions

View File

@ -26,7 +26,7 @@ from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
from mindspore.ops.functional import broadcast_gradient_args, reduced_shape, tuple_div
from mindspore.ops._grad.grad_base import bprop_getters
from mindspore.ops._grad.grad_base import bprop_getters, create_tensor_by_element, dyn_invert_permutation
from mindspore.ops._grad.grad_base import convert_to_tensor
from mindspore.ops._grad.grad_base import sum_grad_reduce_axis, dyn_fill, dyn_rank
from mindspore.ops._grad.grad_base import dyn_ones, dyn_rank_1d
@ -34,6 +34,7 @@ from mindspore.ops.primitive import constexpr
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
from mindspore.ops.operations._inner_ops import DynamicBroadcastGradientArgs, DynamicBroadcastTo, IsSubClass
from mindspore.ops._utils.utils import is_shape_unknown, is_dim_unknown
from mindspore.ops.operations import array_ops as A
shape_op = P.Shape()
dyn_shape_op = P.TensorShape()
@ -910,6 +911,26 @@ def _invert_permutation(perm):
return tuple(out)
def _split_dyn_shape_index(x, axis):
    """Compute the packed 2-D shape and the transpose order for the dynamic-shape ReduceProd grad.

    Tensor-based counterpart of `_split_shape_index`, used when the shape of `x` is
    not known at compile time.

    Args:
        x: Input tensor whose runtime shape is read through `dyn_shape_op`.
        axis: Axes being reduced. Anything that is not already a `Tensor` is wrapped
            into an int64 `Tensor` first.

    Returns:
        A pair `((reduced_num, other_num), perm)`:
        `reduced_num` is the product of the dimension sizes gathered at the reduced axes,
        `other_num` is the product of the sizes gathered at the remaining axes, and
        `perm` is the reduced axes followed by the remaining axes.
    """
    x_shape = dyn_shape_op(x)
    x_rank = dyn_rank(x)

    if not isinstance(axis, Tensor):
        axis = Tensor(axis, dtype=mstype.int64)

    # Flatten the axes to 1-D and wrap negative entries into [0, rank).
    flat_axis = reshape(axis, (-1,))
    wrapped_axis = (flat_axis + x_rank) % x_rank
    reduced = P.Cast()(wrapped_axis, mstype.int64)

    # Enumerate every axis index, then drop the ones that are reduced.
    # NOTE(review): assumes dyn_rank returns a value P.Range accepts as `limit` - confirm.
    all_axes = P.Range()(Tensor(0, dtype=mstype.int64), x_rank, Tensor(1, dtype=mstype.int64))
    other, _ = A.ListDiff()(all_axes, reduced)

    # Reduced axes go first so they can later be collapsed into one leading dimension.
    perm = P.Concat()((reduced, other))

    reduced_dims = P.Cast()(P.Gather()(x_shape, reduced, 0), mstype.int64)
    reduced_num = reduce_prod(reduced_dims, ())
    other_dims = P.Cast()(P.Gather()(x_shape, other, 0), mstype.int64)
    other_num = reduce_prod(other_dims, ())

    pack_shape = (reduced_num, other_num)
    return pack_shape, perm
@bprop_getters.register(P.ReduceProd)
def get_bprop_reduceprod(self):
"""Grad definition for `ReduceProd` operation."""
@ -921,15 +942,27 @@ def get_bprop_reduceprod(self):
"""Grad definition for `Product` operation."""
# Expand dout to full input shape
input_shape = shape_op(x)
output_shape_kept_dims = reduced_shape(input_shape, axis)
if is_shape_unknown(input_shape):
input_shape = dyn_shape_op(x)
input_shape = P.Cast()(input_shape, ms.int64)
output_shape_kept_dims = _dyn_reduced_shape(input_shape, axis, x)
output_shape_kept_dims = P.Cast()(output_shape_kept_dims, ms.int64)
else:
output_shape_kept_dims = reduced_shape(input_shape, axis)
dout = reshape(dout, output_shape_kept_dims)
tile_scaling = tuple_div(input_shape, output_shape_kept_dims)
grad = tile(dout, tile_scaling)
# Pack all reduced dimensions into a single one, so we can perform the cumprod ops.
pack_shape, perm = _split_shape_index(input_shape, axis)
if is_shape_unknown(shape_op(x)):
pack_shape, perm = _split_dyn_shape_index(x, axis)
else:
pack_shape, perm = _split_shape_index(shape_op(x), axis)
permuted = transpose(x, perm)
permuted_shape = shape_op(permuted)
if is_shape_unknown(permuted_shape):
permuted_shape = dyn_shape_op(permuted)
pack_shape = create_tensor_by_element(pack_shape)
reshaped = reshape(permuted, pack_shape)
# Calculate product, leaving out the current entry
@ -939,7 +972,14 @@ def get_bprop_reduceprod(self):
# Invert the transpose and reshape operations.
# Make sure to set the statically known shape information through a reshape.
out = transpose(y, _invert_permutation(perm)) * grad
if is_shape_unknown(shape_op(permuted)):
dout = DynamicBroadcastTo()(dout, input_shape)
out = transpose(y, dyn_invert_permutation(perm)) * dout
else:
tile_scaling = tuple_div(input_shape, output_shape_kept_dims)
grad = tile(dout, tile_scaling)
out = transpose(y, _invert_permutation(perm)) * grad
dx = reshape(out, input_shape)
return dx, zeros_like(axis)