forked from mindspore-Ecosystem/mindspore
fix st error
This commit is contained in:
parent
577b6d0096
commit
3fd8337147
|
@ -30,6 +30,24 @@ class DropoutNet(nn.Cell):
|
|||
return self.relu(self.drop(x))
|
||||
|
||||
|
||||
class _Grad(nn.Cell):
|
||||
def __init__(self, grad, network):
|
||||
super().__init__()
|
||||
self.network = network
|
||||
self.grad = grad
|
||||
|
||||
def construct(self, *inputs):
|
||||
return self.grad(self.network)(*inputs)
|
||||
|
||||
|
||||
class GradOfFirstInput(_Grad):
|
||||
"""
|
||||
get grad of first input
|
||||
"""
|
||||
def __init__(self, network, sens_param=True):
|
||||
super().__init__(grad=GradOperation(sens_param=sens_param), network=network)
|
||||
|
||||
|
||||
def ge_drop_out_0_5(shape):
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
net = DropoutNet(0.5)
|
||||
|
@ -42,7 +60,7 @@ def ge_drop_out_0_5(shape):
|
|||
def ge_dropout_backward_0_5(shape):
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
net = DropoutNet(0.5)
|
||||
grad_net = GradOperation(sens_param=True)(net)
|
||||
grad_net = GradOfFirstInput(net)
|
||||
grad_net.set_train()
|
||||
x = Tensor(np.ones(shape).astype(np.float32))
|
||||
sens = Tensor(np.ones(shape).astype(np.float32))
|
||||
|
|
Loading…
Reference in New Issue