forked from mindspore-Ecosystem/mindspore
fix masked select grad vmap bug
This commit is contained in:
parent
b689657f5d
commit
d22467fa8e
|
@ -976,7 +976,8 @@ def get_masked_select_grad_vmap_rule(prim, axis_size):
|
||||||
x = _bdim_at_front(x, x_dim, axis_size)
|
x = _bdim_at_front(x, x_dim, axis_size)
|
||||||
mask = _bdim_at_front(mask, mask_dim, axis_size)
|
mask = _bdim_at_front(mask, mask_dim, axis_size)
|
||||||
outgrad = _bdim_at_front(outgrad, outgrad_dim, axis_size)
|
outgrad = _bdim_at_front(outgrad, outgrad_dim, axis_size)
|
||||||
outgrad = outgrad.reshape(outgrad[0] * outgrad[0])
|
outgrad_shape = F.shape(outgrad)
|
||||||
|
outgrad = F.reshape(outgrad, (outgrad_shape[0] * outgrad_shape[1],))
|
||||||
x_grad = prim(x, mask, outgrad)
|
x_grad = prim(x, mask, outgrad)
|
||||||
return (x_grad, 0)
|
return (x_grad, 0)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue