forked from mindspore-Ecosystem/mindspore
!10420 [ME]SparseGatherV2 throw exception when grad with 1D tensor input
From: @chenfei52 Reviewed-by: @ginfung,@zh_qh Signed-off-by: @zh_qh
This commit is contained in:
commit
f7eda1118c
|
@ -433,7 +433,10 @@ def get_bprop_sparse_gather_v2(self):
|
|||
x_shp = shape_op(x)
|
||||
if axis == 0:
|
||||
indices_size = (size_op(indices),)
|
||||
x_tail_shp = x_shp[1:]
|
||||
if len(x_shp) <= 1:
|
||||
x_tail_shp = ()
|
||||
else:
|
||||
x_tail_shp = x_shp[1:]
|
||||
values_shape = indices_size + x_tail_shp
|
||||
values = reshape(dout, values_shape)
|
||||
indices_new = reshape(indices, indices_size)
|
||||
|
|
Loading…
Reference in New Issue