fix tile grad
This commit is contained in:
parent
86e0967260
commit
b74bed82b0
|
@ -311,14 +311,9 @@ def get_bprop_tile(self):
|
|||
shape_x = shape_x[1:]
|
||||
shapey = concat((P.Ones()((1,), mstype.int64), shapey))
|
||||
shapey = shapey[1:]
|
||||
if rankx == ranky:
|
||||
tile_shape = transpose(P.Stack()((shapey, shape_x)), (1, 0))
|
||||
r_shape = P.Reshape()(tile_shape, (-1,))
|
||||
axis = get_reduce_axis(r_shape)
|
||||
else:
|
||||
tile_shape = transpose(P.Stack(1)((shapey, shape_x)), (1, 0))
|
||||
r_shape = P.Reshape()(tile_shape, (-1,))
|
||||
axis = get_reduce_axis(r_shape)
|
||||
tile_shape = P.Stack(1)((shapey, shape_x))
|
||||
r_shape = P.Reshape()(tile_shape, (-1,))
|
||||
axis = get_reduce_axis(r_shape)
|
||||
|
||||
dout_reshaped = P.Reshape()(dout, r_shape)
|
||||
dout_origin_dtype = dout_reshaped.dtype
|
||||
|
|
Loading…
Reference in New Issue