Modify nested while testcase

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
This commit is contained in:
zhoufeng 2020-07-01 21:50:33 +08:00
parent 188d1fc777
commit c94dea6a51
1 changed files with 43 additions and 11 deletions

View File

@ -102,16 +102,41 @@ class ControlIfbyIfbyIf(nn.Cell):
class ControlMixedWhileIf(nn.Cell):
def __init__(self):
super().__init__()
self.assign = op.Assign()
self.var = Parameter(initializer(1, (1), mstype.float32), name="var")
def construct(self, x, y, z, c2, c4):
out = self.assign(self.var, c4)
while x < c2:
y = self.assign(self.var, c4)
while y < c2 and x < c2:
if 2 * y < c2:
y = y + 2
else:
y = y + 1
out = out + y
z = self.assign(self.var, c4)
while z < c2:
z = z + 1
out = out + z
x = x + 1
out = out + x
while x < 2 * c2:
y = self.assign(self.var, c4)
x = x + 1
while y < c2:
z = self.assign(self.var, c4)
while z < c2:
z = z + 1
if x < c2:
y = y - 1
else:
y = y + 1
out = out + z
out = out + y
out = out + x
return out
def construct(self, x, y):
y = y + 4
while x < y:
if 2 * x < y:
x = x + 1
else:
x = x + 2
x = x + 3
return x
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@ -130,6 +155,7 @@ def test_simple_if():
expect = input2 * 3 * 3 * 2 + input1
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@ -145,6 +171,7 @@ def test_simple_if_with_assign():
expect = input_data
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@ -158,6 +185,7 @@ def test_if_in_if():
expect = x + y + 3
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@ -175,6 +203,7 @@ def test_if_by_if_by_if():
expect = input_data * 3 * 2 * 2 * 2
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@ -183,7 +212,10 @@ def test_mixed_while_if():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
x = np.array(2).astype(np.int32)
y = np.array(14).astype(np.int32)
z = np.array(1).astype(np.int32)
c2 = Tensor([14], mstype.int32)
c4 = Tensor([0], mstype.int32)
net = ControlMixedWhileIf()
output = net(Tensor(x), Tensor(y))
expect = np.array(22).astype(np.int32)
output = net(Tensor(x), Tensor(y), Tensor(z), c2, c4)
expect = np.array(3318).astype(np.int32)
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)