diff --git a/mindspore/python/mindspore/ops/_grad/grad_array_ops.py b/mindspore/python/mindspore/ops/_grad/grad_array_ops.py index d8ee604e321..f34e2b786f9 100644 --- a/mindspore/python/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/python/mindspore/ops/_grad/grad_array_ops.py @@ -311,14 +311,9 @@ def get_bprop_tile(self): shape_x = shape_x[1:] shapey = concat((P.Ones()((1,), mstype.int64), shapey)) shapey = shapey[1:] - if rankx == ranky: - tile_shape = transpose(P.Stack()((shapey, shape_x)), (1, 0)) - r_shape = P.Reshape()(tile_shape, (-1,)) - axis = get_reduce_axis(r_shape) - else: - tile_shape = transpose(P.Stack(1)((shapey, shape_x)), (1, 0)) - r_shape = P.Reshape()(tile_shape, (-1,)) - axis = get_reduce_axis(r_shape) + tile_shape = P.Stack(1)((shapey, shape_x)) + r_shape = P.Reshape()(tile_shape, (-1,)) + axis = get_reduce_axis(r_shape) dout_reshaped = P.Reshape()(dout, r_shape) dout_origin_dtype = dout_reshaped.dtype