forked from mindspore-Ecosystem/mindspore
Modify nested while testcase
Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
This commit is contained in:
parent
188d1fc777
commit
c94dea6a51
|
@ -102,16 +102,41 @@ class ControlIfbyIfbyIf(nn.Cell):
|
|||
class ControlMixedWhileIf(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.assign = op.Assign()
|
||||
self.var = Parameter(initializer(1, (1), mstype.float32), name="var")
|
||||
|
||||
def construct(self, x, y, z, c2, c4):
|
||||
out = self.assign(self.var, c4)
|
||||
while x < c2:
|
||||
y = self.assign(self.var, c4)
|
||||
while y < c2 and x < c2:
|
||||
if 2 * y < c2:
|
||||
y = y + 2
|
||||
else:
|
||||
y = y + 1
|
||||
out = out + y
|
||||
z = self.assign(self.var, c4)
|
||||
while z < c2:
|
||||
z = z + 1
|
||||
out = out + z
|
||||
x = x + 1
|
||||
out = out + x
|
||||
while x < 2 * c2:
|
||||
y = self.assign(self.var, c4)
|
||||
x = x + 1
|
||||
while y < c2:
|
||||
z = self.assign(self.var, c4)
|
||||
while z < c2:
|
||||
z = z + 1
|
||||
if x < c2:
|
||||
y = y - 1
|
||||
else:
|
||||
y = y + 1
|
||||
out = out + z
|
||||
out = out + y
|
||||
out = out + x
|
||||
return out
|
||||
|
||||
def construct(self, x, y):
|
||||
y = y + 4
|
||||
while x < y:
|
||||
if 2 * x < y:
|
||||
x = x + 1
|
||||
else:
|
||||
x = x + 2
|
||||
x = x + 3
|
||||
return x
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
@ -130,6 +155,7 @@ def test_simple_if():
|
|||
expect = input2 * 3 * 3 * 2 + input1
|
||||
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -145,6 +171,7 @@ def test_simple_if_with_assign():
|
|||
expect = input_data
|
||||
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -158,6 +185,7 @@ def test_if_in_if():
|
|||
expect = x + y + 3
|
||||
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -175,6 +203,7 @@ def test_if_by_if_by_if():
|
|||
expect = input_data * 3 * 2 * 2 * 2
|
||||
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
@ -183,7 +212,10 @@ def test_mixed_while_if():
|
|||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
x = np.array(2).astype(np.int32)
|
||||
y = np.array(14).astype(np.int32)
|
||||
z = np.array(1).astype(np.int32)
|
||||
c2 = Tensor([14], mstype.int32)
|
||||
c4 = Tensor([0], mstype.int32)
|
||||
net = ControlMixedWhileIf()
|
||||
output = net(Tensor(x), Tensor(y))
|
||||
expect = np.array(22).astype(np.int32)
|
||||
output = net(Tensor(x), Tensor(y), Tensor(z), c2, c4)
|
||||
expect = np.array(3318).astype(np.int32)
|
||||
assert np.allclose(expect, output.asnumpy(), 0.0001, 0.0001)
|
||||
|
|
Loading…
Reference in New Issue