forked from mindspore-Ecosystem/mindspore
add inner control flow tests
This commit is contained in:
parent
9c6f067be9
commit
e1f452c151
|
@ -20,8 +20,8 @@ from mindspore.common import dtype as mstype
|
|||
grad_all = C.GradOperation(get_all=True)
|
||||
context.set_context(device_target="Ascend")
|
||||
|
||||
def test_signle_if():
|
||||
class SignleIfNet(nn.Cell):
|
||||
def test_single_if():
|
||||
class SingleIfNet(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
x += 1
|
||||
if x < y:
|
||||
|
@ -44,14 +44,14 @@ def test_signle_if():
|
|||
|
||||
# graph mode
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
if_net = SignleIfNet()
|
||||
if_net = SingleIfNet()
|
||||
net = GradNet(if_net)
|
||||
graph_forward_res = if_net(x, y)
|
||||
graph_backward_res = net(x, y)
|
||||
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
if_net = SignleIfNet()
|
||||
if_net = SingleIfNet()
|
||||
net = GradNet(if_net)
|
||||
pynative_forward_res = if_net(x, y)
|
||||
pynative_backward_res = net(x, y)
|
||||
|
|
|
@ -21,8 +21,8 @@ from mindspore.common import dtype as mstype
|
|||
grad_all = C.GradOperation(get_all=True)
|
||||
context.set_context(device_target="Ascend")
|
||||
|
||||
def test_signle_for():
|
||||
class SignleForNet(nn.Cell):
|
||||
def test_single_for_01():
|
||||
class SingleForNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add = P.Add()
|
||||
|
@ -49,14 +49,58 @@ def test_signle_for():
|
|||
|
||||
# graph mode
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
for_net = SignleForNet()
|
||||
for_net = SingleForNet()
|
||||
net = GradNet(for_net)
|
||||
graph_forward_res = for_net(x, y, z)
|
||||
graph_backward_res = net(x, y, z)
|
||||
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
for_net = SignleForNet()
|
||||
for_net = SingleForNet()
|
||||
net = GradNet(for_net)
|
||||
pynative_forward_res = for_net(x, y, z)
|
||||
pynative_backward_res = net(x, y, z)
|
||||
|
||||
assert graph_forward_res == pynative_forward_res
|
||||
assert graph_backward_res == pynative_backward_res
|
||||
|
||||
|
||||
def test_single_for_02():
|
||||
class SingleForNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add = P.Add()
|
||||
self.mul = P.Mul()
|
||||
|
||||
def construct(self, x, y, z):
|
||||
x = self.add(x, y)
|
||||
for _ in range(10, -5, -3):
|
||||
z = self.add(z, x)
|
||||
y = self.mul(z, y)
|
||||
return y
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
|
||||
def construct(self, *inputs):
|
||||
return grad_all(self.net)(*inputs)
|
||||
|
||||
x = Tensor([2], mstype.int32)
|
||||
y = Tensor([5], mstype.int32)
|
||||
z = Tensor([4], mstype.int32)
|
||||
|
||||
# graph mode
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
for_net = SingleForNet()
|
||||
net = GradNet(for_net)
|
||||
graph_forward_res = for_net(x, y, z)
|
||||
graph_backward_res = net(x, y, z)
|
||||
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
for_net = SingleForNet()
|
||||
net = GradNet(for_net)
|
||||
pynative_forward_res = for_net(x, y, z)
|
||||
pynative_backward_res = net(x, y, z)
|
||||
|
|
|
@ -23,7 +23,7 @@ from mindspore.common import dtype as mstype
|
|||
grad_all = C.GradOperation(get_all=True)
|
||||
context.set_context(device_target="Ascend")
|
||||
|
||||
def test_for_in_if():
|
||||
def test_for_in_if_01():
|
||||
class ForInIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -68,3 +68,53 @@ def test_for_in_if():
|
|||
|
||||
assert graph_forward_res == pynative_forward_res
|
||||
assert graph_backward_res == pynative_backward_res
|
||||
|
||||
def test_for_in_if_02():
|
||||
class ForInIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mul = P.Mul()
|
||||
self.add = P.Add()
|
||||
param_a = np.full((1,), 5, dtype=np.float32)
|
||||
self.param_a = Parameter(Tensor(param_a), name='a')
|
||||
param_b = np.full((1,), 4, dtype=np.float32)
|
||||
self.param_b = Parameter(Tensor(param_b), name='b')
|
||||
|
||||
def func(self, x):
|
||||
for _ in range(0, 5):
|
||||
x = self.add(x, x)
|
||||
self.param_b += 1
|
||||
return x
|
||||
|
||||
def construct(self, x):
|
||||
if self.param_a > self.func(x):
|
||||
x = self.mul(x, 2)
|
||||
return x
|
||||
return x
|
||||
|
||||
class GradNet(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(GradNet, self).__init__()
|
||||
self.net = net
|
||||
|
||||
def construct(self, *inputs):
|
||||
return grad_all(self.net)(*inputs)
|
||||
|
||||
x = Tensor([10], mstype.int32)
|
||||
|
||||
# graph mode
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
for_in_if_net = ForInIfNet()
|
||||
net = GradNet(for_in_if_net)
|
||||
graph_forward_res = for_in_if_net(x)
|
||||
graph_backward_res = net(x)
|
||||
|
||||
# pynative mode
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
for_in_if_net = ForInIfNet()
|
||||
net = GradNet(for_in_if_net)
|
||||
pynative_forward_res = for_in_if_net(x)
|
||||
pynative_backward_res = net(x)
|
||||
|
||||
assert graph_forward_res == pynative_forward_res
|
||||
assert graph_backward_res == pynative_backward_res
|
||||
|
|
|
@ -16,6 +16,7 @@ import numpy as np
|
|||
from mindspore import context
|
||||
from mindspore import Tensor, nn
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
|
|
|
@ -45,7 +45,7 @@ def test_for_after_while_in_if():
|
|||
self.assign(self.param_a, x + self.param_a)
|
||||
y = self.add(y, self.param_b)
|
||||
|
||||
if self.param_b == y - self.param_a:
|
||||
if self.param_b != y - self.param_a:
|
||||
self.param_c = self.div(self.param_c, self.param_b)
|
||||
while self.param_a > x:
|
||||
self.param_c = self.param_a + 2
|
||||
|
|
Loading…
Reference in New Issue