!17911 fix maskedselect grad definition

From: @wuxuejian
Reviewed-by: @liangchenghui,@c_34
Signed-off-by: @liangchenghui
This commit is contained in:
mindspore-ci-bot 2021-06-08 09:36:18 +08:00 committed by Gitee
commit 152dbefa29
2 changed files with 11 additions and 5 deletions

View File

@ -1065,8 +1065,8 @@ def get_bprop_masked_select(self):
"""Generate bprop for MaskedSelect"""
op = G.MaskedSelectGrad()
def bprop(x, mask, dout):
def bprop(x, mask, out, dout):
dx = op(x, mask, dout)
return (dx,)
return (dx, zeros_like(mask))
return bprop

View File

@ -21,7 +21,6 @@ import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops.operations import _grad_ops as G
def maskedselect():
x = np.array([1, 2, 3, 4]).astype(np.int32)
@ -49,13 +48,20 @@ class Grad(nn.Cell):
gout = self.grad(self.network)(x, mask, grad)
return gout
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.op = P.MaskedSelect()
def construct(self, x, mask):
return self.op(x, mask)
def masked_select_grad():
x = np.array([1, 2, 3, 4]).astype(np.int32)
mask = np.array([[0], [1], [0], [1]]).astype(np.bool)
dy = np.array([i for i in range(8)]).astype(np.int32)
grad = G.MaskedSelectGrad()
return grad(Tensor(x), Tensor(mask), Tensor(dy))
grad = Grad(Net())
return grad(Tensor(x), Tensor(mask), Tensor(dy))[0]
@pytest.mark.level0