From b80c289caa3687ece08d85fbafae318e7fbbcd9a Mon Sep 17 00:00:00 2001 From: meizhiyuan Date: Tue, 24 Nov 2020 18:48:20 +0800 Subject: [PATCH] control_flow case --- tests/st/control/test_if_by_if.py | 68 +++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 tests/st/control/test_if_by_if.py diff --git a/tests/st/control/test_if_by_if.py b/tests/st/control/test_if_by_if.py new file mode 100644 index 00000000000..41c143c1d92 --- /dev/null +++ b/tests/st/control/test_if_by_if.py @@ -0,0 +1,68 @@ +import numpy as np +import pytest +import mindspore.context as context +from mindspore import Tensor +from mindspore.common.parameter import Parameter +from mindspore.nn import Cell +import mindspore.ops.operations as P + + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_if_by_if_basic(): + class SubNet(Cell): + def __init__(self): + super().__init__() + self.mul = P.Mul() + self.add = P.TensorAdd() + a = np.full((1,), 5, dtype=np.float32) + self.a = Parameter(Tensor(a), name='a') + b = np.full((1,), 4, dtype=np.float32) + self.b = Parameter(Tensor(b), name='b') + + def construct(self, x): + if self.a > self.b: + x = self.mul(x, 1) + while self.b < 6: + x = self.add(x, x) + self.b += 1 + return x + + class Net(Cell): + def __init__(self): + super().__init__() + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + self.subnet = SubNet() + self.relu = P.ReLU() + self.add = P.TensorAdd() + a = np.full((1,), 5, dtype=np.float32) + self.a = Parameter(Tensor(a), name='a') + b = np.full((1,), 4, dtype=np.float32) + self.b = Parameter(Tensor(b), name='b') + c = np.full((1,), 7, dtype=np.float32) + self.c = Parameter(Tensor(c), name='c') + + def func(self, x): + for _ in range(0, 2): + x = self.add(x, 0) + return x + + def construct(self, x): + if self.a > self.b: + x = self.subnet(x) + else: + x = self.relu(x) + if self.a < self.c: + x = self.func(x) + else: + x = self.add(x, 2) + return x + + input_np = np.random.randn(2, 3, 4, 5).astype(np.float32) + net = Net() + out_ms = net(Tensor(input_np)) + out_np = input_np * 4 + assert np.allclose(out_ms.asnumpy(), out_np)