From e1f452c151a60b8be376114df6ed81324254e640 Mon Sep 17 00:00:00 2001 From: Margaret_wangrui Date: Wed, 21 Apr 2021 10:06:48 +0800 Subject: [PATCH] add inner control flow tests --- tests/st/control/inner/test_000_single_if.py | 8 +-- tests/st/control/inner/test_002_single_for.py | 52 +++++++++++++++++-- tests/st/control/inner/test_030_for_in_if.py | 52 ++++++++++++++++++- tests/st/control/inner/test_032_for_in_for.py | 1 + .../inner/test_320_for_after_while_in_if.py | 2 +- 5 files changed, 105 insertions(+), 10 deletions(-) diff --git a/tests/st/control/inner/test_000_single_if.py b/tests/st/control/inner/test_000_single_if.py index 3b09b56ceb7..6d022ddb435 100644 --- a/tests/st/control/inner/test_000_single_if.py +++ b/tests/st/control/inner/test_000_single_if.py @@ -20,8 +20,8 @@ from mindspore.common import dtype as mstype grad_all = C.GradOperation(get_all=True) context.set_context(device_target="Ascend") -def test_signle_if(): - class SignleIfNet(nn.Cell): +def test_single_if(): + class SingleIfNet(nn.Cell): def construct(self, x, y): x += 1 if x < y: @@ -44,14 +44,14 @@ def test_signle_if(): # graph mode context.set_context(mode=context.GRAPH_MODE) - if_net = SignleIfNet() + if_net = SingleIfNet() net = GradNet(if_net) graph_forward_res = if_net(x, y) graph_backward_res = net(x, y) # pynative mode context.set_context(mode=context.PYNATIVE_MODE) - if_net = SignleIfNet() + if_net = SingleIfNet() net = GradNet(if_net) pynative_forward_res = if_net(x, y) pynative_backward_res = net(x, y) diff --git a/tests/st/control/inner/test_002_single_for.py b/tests/st/control/inner/test_002_single_for.py index ad2795dc787..b357058089c 100644 --- a/tests/st/control/inner/test_002_single_for.py +++ b/tests/st/control/inner/test_002_single_for.py @@ -21,8 +21,8 @@ from mindspore.common import dtype as mstype grad_all = C.GradOperation(get_all=True) context.set_context(device_target="Ascend") -def test_signle_for(): - class SignleForNet(nn.Cell): +def test_single_for_01(): + class SingleForNet(nn.Cell): def __init__(self): super().__init__() self.add = P.Add() @@ -49,14 +49,58 @@ def test_signle_for(): # graph mode context.set_context(mode=context.GRAPH_MODE) - for_net = SignleForNet() + for_net = SingleForNet() net = GradNet(for_net) graph_forward_res = for_net(x, y, z) graph_backward_res = net(x, y, z) # pynative mode context.set_context(mode=context.PYNATIVE_MODE) - for_net = SignleForNet() + for_net = SingleForNet() + net = GradNet(for_net) + pynative_forward_res = for_net(x, y, z) + pynative_backward_res = net(x, y, z) + + assert graph_forward_res == pynative_forward_res + assert graph_backward_res == pynative_backward_res + + +def test_single_for_02(): + class SingleForNet(nn.Cell): + def __init__(self): + super().__init__() + self.add = P.Add() + self.mul = P.Mul() + + def construct(self, x, y, z): + x = self.add(x, y) + for _ in range(10, -5, -3): + z = self.add(z, x) + y = self.mul(z, y) + return y + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + + def construct(self, *inputs): + return grad_all(self.net)(*inputs) + + x = Tensor([2], mstype.int32) + y = Tensor([5], mstype.int32) + z = Tensor([4], mstype.int32) + + # graph mode + context.set_context(mode=context.GRAPH_MODE) + for_net = SingleForNet() + net = GradNet(for_net) + graph_forward_res = for_net(x, y, z) + graph_backward_res = net(x, y, z) + + # pynative mode + context.set_context(mode=context.PYNATIVE_MODE) + for_net = SingleForNet() net = GradNet(for_net) pynative_forward_res = for_net(x, y, z) pynative_backward_res = net(x, y, z) diff --git a/tests/st/control/inner/test_030_for_in_if.py b/tests/st/control/inner/test_030_for_in_if.py index 5b031c00992..721f200ed1a 100644 --- a/tests/st/control/inner/test_030_for_in_if.py +++ b/tests/st/control/inner/test_030_for_in_if.py @@ -23,7 +23,7 @@ from mindspore.common import dtype as mstype grad_all = C.GradOperation(get_all=True) context.set_context(device_target="Ascend") -def test_for_in_if(): +def test_for_in_if_01(): class ForInIfNet(nn.Cell): def __init__(self): super().__init__() @@ -68,3 +68,53 @@ def test_for_in_if(): assert graph_forward_res == pynative_forward_res assert graph_backward_res == pynative_backward_res + +def test_for_in_if_02(): + class ForInIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.mul = P.Mul() + self.add = P.Add() + param_a = np.full((1,), 5, dtype=np.float32) + self.param_a = Parameter(Tensor(param_a), name='a') + param_b = np.full((1,), 4, dtype=np.float32) + self.param_b = Parameter(Tensor(param_b), name='b') + + def func(self, x): + for _ in range(0, 5): + x = self.add(x, x) + self.param_b += 1 + return x + + def construct(self, x): + if self.param_a > self.func(x): + x = self.mul(x, 2) + return x + return x + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + + def construct(self, *inputs): + return grad_all(self.net)(*inputs) + + x = Tensor([10], mstype.int32) + + # graph mode + context.set_context(mode=context.GRAPH_MODE) + for_in_if_net = ForInIfNet() + net = GradNet(for_in_if_net) + graph_forward_res = for_in_if_net(x) + graph_backward_res = net(x) + + # pynative mode + context.set_context(mode=context.PYNATIVE_MODE) + for_in_if_net = ForInIfNet() + net = GradNet(for_in_if_net) + pynative_forward_res = for_in_if_net(x) + pynative_backward_res = net(x) + + assert graph_forward_res == pynative_forward_res + assert graph_backward_res == pynative_backward_res diff --git a/tests/st/control/inner/test_032_for_in_for.py b/tests/st/control/inner/test_032_for_in_for.py index c717bd22742..421753a44c3 100644 --- a/tests/st/control/inner/test_032_for_in_for.py +++ b/tests/st/control/inner/test_032_for_in_for.py @@ -16,6 +16,7 @@ import numpy as np from mindspore import context from mindspore import Tensor, nn from mindspore.common.parameter import Parameter +from mindspore.ops import composite as C from mindspore.ops import operations as P from mindspore.common import dtype as mstype diff --git a/tests/st/control/inner/test_320_for_after_while_in_if.py b/tests/st/control/inner/test_320_for_after_while_in_if.py index dd24811ed0f..0cdebb97057 100644 --- a/tests/st/control/inner/test_320_for_after_while_in_if.py +++ b/tests/st/control/inner/test_320_for_after_while_in_if.py @@ -45,7 +45,7 @@ def test_for_after_while_in_if(): self.assign(self.param_a, x + self.param_a) y = self.add(y, self.param_b) - if self.param_b == y - self.param_a: + if self.param_b != y - self.param_a: self.param_c = self.div(self.param_c, self.param_b) while self.param_a > x: self.param_c = self.param_a + 2