forked from mindspore-Ecosystem/mindspore
fix masked select grad vmap bug
This commit is contained in:
parent
b689657f5d
commit
d22467fa8e
|
@ -976,7 +976,8 @@ def get_masked_select_grad_vmap_rule(prim, axis_size):
|
|||
x = _bdim_at_front(x, x_dim, axis_size)
|
||||
mask = _bdim_at_front(mask, mask_dim, axis_size)
|
||||
outgrad = _bdim_at_front(outgrad, outgrad_dim, axis_size)
|
||||
outgrad = outgrad.reshape(outgrad[0] * outgrad[0])
|
||||
outgrad_shape = F.shape(outgrad)
|
||||
outgrad = F.reshape(outgrad, (outgrad_shape[0] * outgrad_shape[1],))
|
||||
x_grad = prim(x, mask, outgrad)
|
||||
return (x_grad, 0)
|
||||
|
||||
|
|
Loading…
Reference in New Issue