forked from mindspore-Ecosystem/mindspore
fix maskedselect grad definition
This commit is contained in:
parent
c564b2cc86
commit
1687d2a991
|
@ -1065,8 +1065,8 @@ def get_bprop_masked_select(self):
|
|||
"""Generate bprop for MaskedSelect"""
|
||||
op = G.MaskedSelectGrad()
|
||||
|
||||
def bprop(x, mask, dout):
|
||||
def bprop(x, mask, out, dout):
|
||||
dx = op(x, mask, dout)
|
||||
return (dx,)
|
||||
return (dx, zeros_like(mask))
|
||||
|
||||
return bprop
|
||||
|
|
|
@ -21,7 +21,6 @@ import mindspore.nn as nn
|
|||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops.operations import _grad_ops as G
|
||||
|
||||
def maskedselect():
|
||||
x = np.array([1, 2, 3, 4]).astype(np.int32)
|
||||
|
@ -49,13 +48,20 @@ class Grad(nn.Cell):
|
|||
gout = self.grad(self.network)(x, mask, grad)
|
||||
return gout
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.op = P.MaskedSelect()
|
||||
|
||||
def construct(self, x, mask):
|
||||
return self.op(x, mask)
|
||||
|
||||
def masked_select_grad():
|
||||
x = np.array([1, 2, 3, 4]).astype(np.int32)
|
||||
mask = np.array([[0], [1], [0], [1]]).astype(np.bool)
|
||||
dy = np.array([i for i in range(8)]).astype(np.int32)
|
||||
grad = G.MaskedSelectGrad()
|
||||
return grad(Tensor(x), Tensor(mask), Tensor(dy))
|
||||
grad = Grad(Net())
|
||||
return grad(Tensor(x), Tensor(mask), Tensor(dy))[0]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
|
Loading…
Reference in New Issue