!45564 add ReduceProd bprop CPU support for dynamic shape
Merge pull request !45564 from zhangdong/trans_dyn
Commit 666428a765
@@ -26,7 +26,7 @@ from mindspore.ops import operations as P
 from mindspore.ops.operations import _grad_ops as G
 from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
 from mindspore.ops.functional import broadcast_gradient_args, reduced_shape, tuple_div
-from mindspore.ops._grad.grad_base import bprop_getters
+from mindspore.ops._grad.grad_base import bprop_getters, create_tensor_by_element, dyn_invert_permutation
 from mindspore.ops._grad.grad_base import convert_to_tensor
 from mindspore.ops._grad.grad_base import sum_grad_reduce_axis, dyn_fill, dyn_rank
 from mindspore.ops._grad.grad_base import dyn_ones, dyn_rank_1d
@@ -34,6 +34,7 @@ from mindspore.ops.primitive import constexpr
 from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
 from mindspore.ops.operations._inner_ops import DynamicBroadcastGradientArgs, DynamicBroadcastTo, IsSubClass
 from mindspore.ops._utils.utils import is_shape_unknown, is_dim_unknown
+from mindspore.ops.operations import array_ops as A
 
 shape_op = P.Shape()
 dyn_shape_op = P.TensorShape()
@@ -910,6 +911,26 @@ def _invert_permutation(perm):
     return tuple(out)
 
 
+def _split_dyn_shape_index(x, axis):
+    """Compute the pack shape (reduced_num, other_num) and transpose permutation for the dynamic-shape ReduceProd grad."""
+    input_shape = dyn_shape_op(x)
+    rank = dyn_rank(x)
+    if not isinstance(axis, Tensor):
+        axis = Tensor(axis, dtype=mstype.int64)
+    reduction_indices = reshape(axis, (-1,))
+    reduction_indices = (reduction_indices + rank) % rank
+    reduced = P.Cast()(reduction_indices, mstype.int64)
+
+    start = Tensor(0, dtype=mstype.int64)
+    delta = Tensor(1, dtype=mstype.int64)
+    idx = P.Range()(start, rank, delta)
+    other, _ = A.ListDiff()(idx, reduced)
+    perm = P.Concat()((reduced, other))
+    reduced_num = reduce_prod(P.Cast()(P.Gather()(input_shape, reduced, 0), mstype.int64), ())
+    other_num = reduce_prod(P.Cast()(P.Gather()(input_shape, other, 0), mstype.int64), ())
+    return (reduced_num, other_num), perm
+
+
 @bprop_getters.register(P.ReduceProd)
 def get_bprop_reduceprod(self):
     """Grad definition for `ReduceProd` operation."""
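For intuition, the new helper mirrors the classic static-shape computation with runtime ops: Range enumerates all dimension indices, ListDiff separates the kept dimensions from the reduced ones, Concat builds the transpose permutation, and Gather plus ReduceProd count the elements in each group. A minimal NumPy sketch of the same computation (illustrative names, assumes the axis entries are unique; not part of the diff):

    import numpy as np

    def split_shape_index_static(shape, axis):
        """Static mirror of _split_dyn_shape_index: pack sizes plus permutation."""
        rank = len(shape)
        # Normalize negative axes, as (reduction_indices + rank) % rank does above.
        reduced = np.array(axis, dtype=np.int64).reshape(-1) % rank
        # ListDiff equivalent: dimensions that are not reduced, in ascending order.
        other = np.setdiff1d(np.arange(rank, dtype=np.int64), reduced)
        # Concat equivalent: reduced dims first, kept dims after.
        perm = np.concatenate([reduced, other])
        # Gather + ReduceProd equivalent: element count of each group.
        return (int(np.prod(shape[reduced])), int(np.prod(shape[other]))), perm

    # Reducing axis 1 of a (2, 3, 4) tensor packs it to (3, 8) via perm (1, 0, 2).
    print(split_shape_index_static(np.array([2, 3, 4]), 1))  # ((3, 8), array([1, 0, 2]))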
@@ -921,15 +942,27 @@ def get_bprop_reduceprod(self):
         """Grad definition for `Product` operation."""
         # Expand dout to full input shape
         input_shape = shape_op(x)
-        output_shape_kept_dims = reduced_shape(input_shape, axis)
+        if is_shape_unknown(input_shape):
+            input_shape = dyn_shape_op(x)
+            input_shape = P.Cast()(input_shape, ms.int64)
+            output_shape_kept_dims = _dyn_reduced_shape(input_shape, axis, x)
+            output_shape_kept_dims = P.Cast()(output_shape_kept_dims, ms.int64)
+        else:
+            output_shape_kept_dims = reduced_shape(input_shape, axis)
+
         dout = reshape(dout, output_shape_kept_dims)
-        tile_scaling = tuple_div(input_shape, output_shape_kept_dims)
-        grad = tile(dout, tile_scaling)
 
         # Pack all reduced dimensions into a single one, so we can perform the cumprod ops.
-        pack_shape, perm = _split_shape_index(input_shape, axis)
+        if is_shape_unknown(shape_op(x)):
+            pack_shape, perm = _split_dyn_shape_index(x, axis)
+        else:
+            pack_shape, perm = _split_shape_index(shape_op(x), axis)
+
         permuted = transpose(x, perm)
         permuted_shape = shape_op(permuted)
+        if is_shape_unknown(permuted_shape):
+            permuted_shape = dyn_shape_op(permuted)
+            pack_shape = create_tensor_by_element(pack_shape)
         reshaped = reshape(permuted, pack_shape)
 
         # Calculate product, leaving out the current entry
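The packing above exists so the next step can take exclusive cumulative products along a single axis: the gradient of a product with respect to each factor is the product of all the other factors, which can be formed without division (and therefore safely when an entry is zero). A standalone NumPy sketch of that identity (not part of the diff):

    import numpy as np

    # d/dx[i] prod(x) = prod(x[:i]) * prod(x[i+1:]), via left/right exclusive cumprods.
    x = np.array([2.0, 3.0, 4.0])
    left = np.cumprod(np.concatenate(([1.0], x[:-1])))               # products before i
    right = np.cumprod(np.concatenate(([1.0], x[::-1][:-1])))[::-1]  # products after i
    print(left * right)  # [12. 8. 6.], i.e. prod(x)/x[i], but division-free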
@@ -939,7 +972,14 @@ def get_bprop_reduceprod(self):
 
         # Invert the transpose and reshape operations.
         # Make sure to set the statically known shape information through a reshape.
-        out = transpose(y, _invert_permutation(perm)) * grad
+        if is_shape_unknown(shape_op(permuted)):
+            dout = DynamicBroadcastTo()(dout, input_shape)
+            out = transpose(y, dyn_invert_permutation(perm)) * dout
+        else:
+            tile_scaling = tuple_div(input_shape, output_shape_kept_dims)
+            grad = tile(dout, tile_scaling)
+            out = transpose(y, _invert_permutation(perm)) * grad
+
         dx = reshape(out, input_shape)
         return dx, zeros_like(axis)
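End to end, the change lets the CPU backend differentiate ReduceProd when the input shape is only known at run time. A minimal usage sketch, assuming MindSpore's GradOperation and Cell.set_inputs dynamic-shape API (the network and shapes here are illustrative):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, nn, ops

    class Net(nn.Cell):
        def __init__(self, axis):
            super().__init__()
            self.reduce_prod = ops.ReduceProd()
            self.axis = axis

        def construct(self, x):
            return self.reduce_prod(x, self.axis)

    ms.set_context(device_target="CPU")
    net = Net(axis=1)
    # Mark the first dimension as dynamic so the dynamic-shape bprop path is exercised.
    net.set_inputs(Tensor(shape=[None, 3], dtype=ms.float32))
    grad_net = ops.GradOperation()(net)
    x = Tensor(np.array([[2.0, 3.0, 4.0]], np.float32))
    print(grad_net(x))  # [[12. 8. 6.]]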