fix test cases
This commit is contained in:
parent
7d4f481884
commit
113641decd
|
@ -26,6 +26,9 @@ from mindspore.ops import operations as P
|
|||
# from tests.vm_impl.vm_interface import *
|
||||
# from tests.vm_impl import *
|
||||
|
||||
grad_by_list = C.GradOperation('get_by_list', get_by_list=True)
|
||||
grad_all = C.GradOperation('get_all', get_all=True)
|
||||
|
||||
|
||||
def setup_module():
|
||||
context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=False)
|
||||
|
@ -86,7 +89,7 @@ def test_while_opt_endless():
|
|||
|
||||
@ms_function
|
||||
def construct(self, *inputs):
|
||||
return C.grad_all(self.net)(*inputs)
|
||||
return grad_all(self.net)(*inputs)
|
||||
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
|
@ -149,7 +152,7 @@ def test_while_with_param_grad_with_const_branch():
|
|||
|
||||
@ms_function
|
||||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
return grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
|
@ -189,7 +192,7 @@ def test_for_while_with_param_grad_with_const_branch():
|
|||
|
||||
@ms_function
|
||||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
return grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
|
@ -226,7 +229,7 @@ def test_for_while_with_param_grad_basic():
|
|||
|
||||
@ms_function
|
||||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
return grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
|
@ -263,7 +266,7 @@ def test_for_while_with_param_grad_normal():
|
|||
|
||||
@ms_function
|
||||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
return grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
|
@ -297,7 +300,7 @@ def test_while_with_param_basic_grad():
|
|||
|
||||
@ms_function
|
||||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
return grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
|
@ -331,7 +334,7 @@ def test_while_with_param_basic_grad_mul():
|
|||
|
||||
@ms_function
|
||||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
return grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
|
@ -366,7 +369,7 @@ def test_while_with_param_basic_grad_two():
|
|||
|
||||
@ms_function
|
||||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
return grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
|
@ -402,7 +405,7 @@ def test_while_with_param_basic_grad_three():
|
|||
|
||||
@ms_function
|
||||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
return grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
|
@ -439,7 +442,7 @@ def test_while_if_with_param_grad():
|
|||
|
||||
@ms_function
|
||||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
return grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
|
@ -472,7 +475,7 @@ def test_while_with_param_grad_not_enter_while():
|
|||
|
||||
@ms_function
|
||||
def construct(self, a, b, c):
|
||||
return C.grad_by_list(self.net, self.weights)(a, b, c)
|
||||
return grad_by_list(self.net, self.weights)(a, b, c)
|
||||
|
||||
while_net = MyWhileNet()
|
||||
net = GradNet(while_net)
|
||||
|
@ -534,7 +537,7 @@ def test_with_param_if_by_if_grad_inputs():
|
|||
|
||||
@ms_function
|
||||
def construct(self, *inputs):
|
||||
return C.grad_all(self.net)(*inputs)
|
||||
return grad_all(self.net)(*inputs)
|
||||
|
||||
if_net = MyIfByIfNet()
|
||||
net = GradNet(if_net)
|
||||
|
@ -568,7 +571,7 @@ def test_with_param_if_by_if_grad_parameter():
|
|||
|
||||
@ms_function
|
||||
def construct(self, *inputs):
|
||||
return C.grad_by_list(self.net, self.weights)(*inputs)
|
||||
return grad_by_list(self.net, self.weights)(*inputs)
|
||||
|
||||
if_net = MyIfByIfNet()
|
||||
net = GradNet(if_net)
|
||||
|
@ -600,7 +603,7 @@ def test_with_param_if_by_if_grad_param_excute_null():
|
|||
|
||||
@ms_function
|
||||
def construct(self, *inputs):
|
||||
return C.grad_by_list(self.net, self.weights)(*inputs)
|
||||
return grad_by_list(self.net, self.weights)(*inputs)
|
||||
|
||||
if_net = MyIfByIfNet()
|
||||
net = GradNet(if_net)
|
||||
|
@ -634,7 +637,7 @@ def test_if_by_if_return_inside_grad():
|
|||
|
||||
@ms_function
|
||||
def construct(self, *inputs):
|
||||
return C.grad_by_list(self.net, self.weights)(*inputs)
|
||||
return grad_by_list(self.net, self.weights)(*inputs)
|
||||
|
||||
if_net = MyIfByIfNet()
|
||||
net = GradNet(if_net)
|
||||
|
|
Loading…
Reference in New Issue