fix tile grad

This commit is contained in:
r1chardf1d0 2022-11-03 12:06:55 +08:00
parent 86e0967260
commit b74bed82b0
1 changed files with 3 additions and 8 deletions

View File

@ -311,14 +311,9 @@ def get_bprop_tile(self):
shape_x = shape_x[1:]
shapey = concat((P.Ones()((1,), mstype.int64), shapey))
shapey = shapey[1:]
if rankx == ranky:
tile_shape = transpose(P.Stack()((shapey, shape_x)), (1, 0))
r_shape = P.Reshape()(tile_shape, (-1,))
axis = get_reduce_axis(r_shape)
else:
tile_shape = transpose(P.Stack(1)((shapey, shape_x)), (1, 0))
r_shape = P.Reshape()(tile_shape, (-1,))
axis = get_reduce_axis(r_shape)
tile_shape = P.Stack(1)((shapey, shape_x))
r_shape = P.Reshape()(tile_shape, (-1,))
axis = get_reduce_axis(r_shape)
dout_reshaped = P.Reshape()(dout, r_shape)
dout_origin_dtype = dout_reshaped.dtype