forked from mindspore-Ecosystem/mindspore
!17911 fix maskedselect grad definition
From: @wuxuejian Reviewed-by: @liangchenghui,@c_34 Signed-off-by: @liangchenghui
This commit is contained in:
commit
152dbefa29
|
@ -1065,8 +1065,8 @@ def get_bprop_masked_select(self):
|
||||||
"""Generate bprop for MaskedSelect"""
|
"""Generate bprop for MaskedSelect"""
|
||||||
op = G.MaskedSelectGrad()
|
op = G.MaskedSelectGrad()
|
||||||
|
|
||||||
def bprop(x, mask, dout):
|
def bprop(x, mask, out, dout):
|
||||||
dx = op(x, mask, dout)
|
dx = op(x, mask, dout)
|
||||||
return (dx,)
|
return (dx, zeros_like(mask))
|
||||||
|
|
||||||
return bprop
|
return bprop
|
||||||
|
|
|
@ -21,7 +21,6 @@ import mindspore.nn as nn
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
from mindspore.ops import composite as C
|
from mindspore.ops import composite as C
|
||||||
from mindspore.ops.operations import _grad_ops as G
|
|
||||||
|
|
||||||
def maskedselect():
|
def maskedselect():
|
||||||
x = np.array([1, 2, 3, 4]).astype(np.int32)
|
x = np.array([1, 2, 3, 4]).astype(np.int32)
|
||||||
|
@ -49,13 +48,20 @@ class Grad(nn.Cell):
|
||||||
gout = self.grad(self.network)(x, mask, grad)
|
gout = self.grad(self.network)(x, mask, grad)
|
||||||
return gout
|
return gout
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.op = P.MaskedSelect()
|
||||||
|
|
||||||
|
def construct(self, x, mask):
|
||||||
|
return self.op(x, mask)
|
||||||
|
|
||||||
def masked_select_grad():
|
def masked_select_grad():
|
||||||
x = np.array([1, 2, 3, 4]).astype(np.int32)
|
x = np.array([1, 2, 3, 4]).astype(np.int32)
|
||||||
mask = np.array([[0], [1], [0], [1]]).astype(np.bool)
|
mask = np.array([[0], [1], [0], [1]]).astype(np.bool)
|
||||||
dy = np.array([i for i in range(8)]).astype(np.int32)
|
dy = np.array([i for i in range(8)]).astype(np.int32)
|
||||||
grad = G.MaskedSelectGrad()
|
grad = Grad(Net())
|
||||||
return grad(Tensor(x), Tensor(mask), Tensor(dy))
|
return grad(Tensor(x), Tensor(mask), Tensor(dy))[0]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
@pytest.mark.level0
|
||||||
|
|
Loading…
Reference in New Issue