fix st error

This commit is contained in:
chenmai1102 2022-05-23 10:24:46 +08:00
parent 577b6d0096
commit 3fd8337147
1 changed files with 19 additions and 1 deletions

View File

@ -30,6 +30,24 @@ class DropoutNet(nn.Cell):
return self.relu(self.drop(x))
class _Grad(nn.Cell):
def __init__(self, grad, network):
super().__init__()
self.network = network
self.grad = grad
def construct(self, *inputs):
return self.grad(self.network)(*inputs)
class GradOfFirstInput(_Grad):
"""
get grad of first input
"""
def __init__(self, network, sens_param=True):
super().__init__(grad=GradOperation(sens_param=sens_param), network=network)
def ge_drop_out_0_5(shape):
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
net = DropoutNet(0.5)
@ -42,7 +60,7 @@ def ge_drop_out_0_5(shape):
def ge_dropout_backward_0_5(shape):
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
net = DropoutNet(0.5)
grad_net = GradOperation(sens_param=True)(net)
grad_net = GradOfFirstInput(net)
grad_net.set_train()
x = Tensor(np.ones(shape).astype(np.float32))
sens = Tensor(np.ones(shape).astype(np.float32))