don't slice shape if shape size<=1 of sparsegatherv2 grad

This commit is contained in:
chenfei 2020-12-24 09:33:25 +08:00
parent bceb03a07e
commit ccb5bce359
1 changed files with 4 additions and 1 deletions

View File

@ -433,6 +433,9 @@ def get_bprop_sparse_gather_v2(self):
x_shp = shape_op(x)
if axis == 0:
indices_size = (size_op(indices),)
if len(x_shp) <= 1:
x_tail_shp = ()
else:
x_tail_shp = x_shp[1:]
values_shape = indices_size + x_tail_shp
values = reshape(dout, values_shape)