!17911 fix maskedselect grad definition

From: @wuxuejian
Reviewed-by: @liangchenghui,@c_34
Signed-off-by: @liangchenghui
This commit is contained in:
mindspore-ci-bot 2021-06-08 09:36:18 +08:00 committed by Gitee
commit 152dbefa29
2 changed files with 11 additions and 5 deletions

View File

@ -1065,8 +1065,8 @@ def get_bprop_masked_select(self):
"""Generate bprop for MaskedSelect""" """Generate bprop for MaskedSelect"""
op = G.MaskedSelectGrad() op = G.MaskedSelectGrad()
def bprop(x, mask, dout): def bprop(x, mask, out, dout):
dx = op(x, mask, dout) dx = op(x, mask, dout)
return (dx,) return (dx, zeros_like(mask))
return bprop return bprop

View File

@ -21,7 +21,6 @@ import mindspore.nn as nn
from mindspore import Tensor from mindspore import Tensor
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import composite as C from mindspore.ops import composite as C
from mindspore.ops.operations import _grad_ops as G
def maskedselect(): def maskedselect():
x = np.array([1, 2, 3, 4]).astype(np.int32) x = np.array([1, 2, 3, 4]).astype(np.int32)
@ -49,13 +48,20 @@ class Grad(nn.Cell):
gout = self.grad(self.network)(x, mask, grad) gout = self.grad(self.network)(x, mask, grad)
return gout return gout
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.op = P.MaskedSelect()
def construct(self, x, mask):
return self.op(x, mask)
def masked_select_grad(): def masked_select_grad():
x = np.array([1, 2, 3, 4]).astype(np.int32) x = np.array([1, 2, 3, 4]).astype(np.int32)
mask = np.array([[0], [1], [0], [1]]).astype(np.bool) mask = np.array([[0], [1], [0], [1]]).astype(np.bool)
dy = np.array([i for i in range(8)]).astype(np.int32) dy = np.array([i for i in range(8)]).astype(np.int32)
grad = G.MaskedSelectGrad() grad = Grad(Net())
return grad(Tensor(x), Tensor(mask), Tensor(dy)) return grad(Tensor(x), Tensor(mask), Tensor(dy))[0]
@pytest.mark.level0 @pytest.mark.level0