add inner control flow tests

Margaret_wangrui 2021-04-21 10:06:48 +08:00
parent 9c6f067be9
commit e1f452c151
5 changed files with 105 additions and 10 deletions

View File

@@ -20,8 +20,8 @@ from mindspore.common import dtype as mstype
 grad_all = C.GradOperation(get_all=True)
 context.set_context(device_target="Ascend")

-def test_signle_if():
-    class SignleIfNet(nn.Cell):
+def test_single_if():
+    class SingleIfNet(nn.Cell):
         def construct(self, x, y):
             x += 1
             if x < y:
@@ -44,14 +44,14 @@ def test_signle_if():

     # graph mode
     context.set_context(mode=context.GRAPH_MODE)
-    if_net = SignleIfNet()
+    if_net = SingleIfNet()
     net = GradNet(if_net)
     graph_forward_res = if_net(x, y)
     graph_backward_res = net(x, y)

     # pynative mode
     context.set_context(mode=context.PYNATIVE_MODE)
-    if_net = SignleIfNet()
+    if_net = SingleIfNet()
     net = GradNet(if_net)
     pynative_forward_res = if_net(x, y)
     pynative_backward_res = net(x, y)
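
The renamed tests above all share one comparison pattern: run the cell forward and backward in GRAPH_MODE, rerun it in PYNATIVE_MODE, and assert the results match (the GradNet wrapper and the assertions appear in the later hunks of this diff). Below is a minimal self-contained sketch of that pattern; the compare_modes helper is hypothetical and not part of the commit.

# Minimal sketch (not part of the commit): the graph-vs-pynative comparison pattern
# shared by these tests. compare_modes is a hypothetical helper; GradNet and grad_all
# mirror the definitions shown in the test files of this diff.
from mindspore import context, nn
from mindspore.ops import composite as C

grad_all = C.GradOperation(get_all=True)

class GradNet(nn.Cell):
    def __init__(self, net):
        super(GradNet, self).__init__()
        self.net = net

    def construct(self, *inputs):
        # Gradients of the wrapped cell w.r.t. every positional input
        return grad_all(self.net)(*inputs)

def compare_modes(net_cls, *inputs):
    # Build and run the cell in graph mode, then again in pynative mode,
    # and require identical forward results and gradients.
    context.set_context(mode=context.GRAPH_MODE)
    net = net_cls()
    graph_forward = net(*inputs)
    graph_backward = GradNet(net)(*inputs)

    context.set_context(mode=context.PYNATIVE_MODE)
    net = net_cls()
    pynative_forward = net(*inputs)
    pynative_backward = GradNet(net)(*inputs)

    assert graph_forward == pynative_forward
    assert graph_backward == pynative_backward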

View File

@@ -21,8 +21,8 @@ from mindspore.common import dtype as mstype
 grad_all = C.GradOperation(get_all=True)
 context.set_context(device_target="Ascend")

-def test_signle_for():
-    class SignleForNet(nn.Cell):
+def test_single_for_01():
+    class SingleForNet(nn.Cell):
         def __init__(self):
             super().__init__()
             self.add = P.Add()
@@ -49,14 +49,58 @@ def test_signle_for():

     # graph mode
     context.set_context(mode=context.GRAPH_MODE)
-    for_net = SignleForNet()
+    for_net = SingleForNet()
     net = GradNet(for_net)
     graph_forward_res = for_net(x, y, z)
     graph_backward_res = net(x, y, z)

     # pynative mode
     context.set_context(mode=context.PYNATIVE_MODE)
-    for_net = SignleForNet()
+    for_net = SingleForNet()
     net = GradNet(for_net)
     pynative_forward_res = for_net(x, y, z)
     pynative_backward_res = net(x, y, z)

     assert graph_forward_res == pynative_forward_res
     assert graph_backward_res == pynative_backward_res
+
+
+def test_single_for_02():
+    class SingleForNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.add = P.Add()
+            self.mul = P.Mul()
+
+        def construct(self, x, y, z):
+            x = self.add(x, y)
+            for _ in range(10, -5, -3):
+                z = self.add(z, x)
+            y = self.mul(z, y)
+            return y
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor([2], mstype.int32)
+    y = Tensor([5], mstype.int32)
+    z = Tensor([4], mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    for_net = SingleForNet()
+    net = GradNet(for_net)
+    graph_forward_res = for_net(x, y, z)
+    graph_backward_res = net(x, y, z)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    for_net = SingleForNet()
+    net = GradNet(for_net)
+    pynative_forward_res = for_net(x, y, z)
+    pynative_backward_res = net(x, y, z)
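
Side note on the loop bounds introduced in test_single_for_02: range(10, -5, -3) runs the body five times, so z is accumulated five times before y is computed. A one-line check in plain Python:

# Trip count of the for loop in test_single_for_02 (plain Python, no MindSpore needed)
print(list(range(10, -5, -3)))  # [10, 7, 4, 1, -2] -> 5 iterations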

View File

@@ -23,7 +23,7 @@ from mindspore.common import dtype as mstype
 grad_all = C.GradOperation(get_all=True)
 context.set_context(device_target="Ascend")

-def test_for_in_if():
+def test_for_in_if_01():
     class ForInIfNet(nn.Cell):
         def __init__(self):
             super().__init__()
@@ -68,3 +68,53 @@ def test_for_in_if():

     assert graph_forward_res == pynative_forward_res
     assert graph_backward_res == pynative_backward_res
+
+
+def test_for_in_if_02():
+    class ForInIfNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.mul = P.Mul()
+            self.add = P.Add()
+            param_a = np.full((1,), 5, dtype=np.float32)
+            self.param_a = Parameter(Tensor(param_a), name='a')
+            param_b = np.full((1,), 4, dtype=np.float32)
+            self.param_b = Parameter(Tensor(param_b), name='b')
+
+        def func(self, x):
+            for _ in range(0, 5):
+                x = self.add(x, x)
+                self.param_b += 1
+            return x
+
+        def construct(self, x):
+            if self.param_a > self.func(x):
+                x = self.mul(x, 2)
+                return x
+            return x
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor([10], mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    for_in_if_net = ForInIfNet()
+    net = GradNet(for_in_if_net)
+    graph_forward_res = for_in_if_net(x)
+    graph_backward_res = net(x)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    for_in_if_net = ForInIfNet()
+    net = GradNet(for_in_if_net)
+    pynative_forward_res = for_in_if_net(x)
+    pynative_backward_res = net(x)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res
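
For reference, grad_all = C.GradOperation(get_all=True) (defined at the top of each of these test files) turns a cell into a function that returns the gradients with respect to every positional input as a tuple, which is what GradNet forwards in the tests above. A minimal sketch on a hypothetical SquareNet cell, not taken from the diff:

# Minimal sketch (illustrative only): GradOperation(get_all=True) on a simple cell.
from mindspore import context, Tensor, nn
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype

context.set_context(mode=context.PYNATIVE_MODE)
grad_all = C.GradOperation(get_all=True)

class SquareNet(nn.Cell):          # hypothetical example cell, not from the diff
    def __init__(self):
        super().__init__()
        self.mul = P.Mul()

    def construct(self, x):
        return self.mul(x, x)      # y = x * x

net = SquareNet()
grads = grad_all(net)(Tensor([3.0], mstype.float32))
print(grads)                       # tuple with one gradient tensor: dy/dx = 2 * x = 6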

View File

@@ -16,6 +16,7 @@ import numpy as np
 from mindspore import context
 from mindspore import Tensor, nn
+from mindspore.common.parameter import Parameter
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from mindspore.common import dtype as mstype

View File

@@ -45,7 +45,7 @@ def test_for_after_while_in_if():
             self.assign(self.param_a, x + self.param_a)
             y = self.add(y, self.param_b)
-            if self.param_b == y - self.param_a:
+            if self.param_b != y - self.param_a:
                 self.param_c = self.div(self.param_c, self.param_b)
                 while self.param_a > x:
                     self.param_c = self.param_a + 2