fix masked select grad vmap bug

This commit is contained in:
zhaodezan 2022-07-26 15:02:43 +08:00
parent b689657f5d
commit d22467fa8e
1 changed files with 2 additions and 1 deletions

View File

@ -976,7 +976,8 @@ def get_masked_select_grad_vmap_rule(prim, axis_size):
x = _bdim_at_front(x, x_dim, axis_size)
mask = _bdim_at_front(mask, mask_dim, axis_size)
outgrad = _bdim_at_front(outgrad, outgrad_dim, axis_size)
outgrad = outgrad.reshape(outgrad[0] * outgrad[0])
outgrad_shape = F.shape(outgrad)
outgrad = F.reshape(outgrad, (outgrad_shape[0] * outgrad_shape[1],))
x_grad = prim(x, mask, outgrad)
return (x_grad, 0)