diff --git a/mindspore/ccsrc/frontend/optimizer/opt.cc b/mindspore/ccsrc/frontend/optimizer/opt.cc
index 5474bb5c1e1..b98df08a8e1 100644
--- a/mindspore/ccsrc/frontend/optimizer/opt.cc
+++ b/mindspore/ccsrc/frontend/optimizer/opt.cc
@@ -103,8 +103,8 @@ static bool isTraversable(const AnfNodePtr &node) {
   return false;
 }
 
-static inline AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node,
-                                     const SubstitutionPtr &substitution) {
+static AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node,
+                              const SubstitutionPtr &substitution) {
   auto manager = optimizer->manager();
   bool is_match = substitution->predicate_(node);
   if (is_match) {
@@ -126,8 +126,8 @@ static inline AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNod
   return nullptr;
 }
 
-static inline void UpdateTransformingList(const OptimizerPtr &optimizer, const AnfNodePtr &node,
-                                          std::deque<AnfNodePtr> *todo, bool change, size_t seen) {
+static void UpdateTransformingList(const OptimizerPtr &optimizer, const AnfNodePtr &node, std::deque<AnfNodePtr> *todo,
+                                   bool change, size_t seen) {
   if (IsValueNode<FuncGraph>(node)) {
     (*todo).emplace_back(GetValueNode<FuncGraphPtr>(node)->output());
   }
@@ -238,6 +238,23 @@ bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, cons
   return changes;
 }
 
+void SubstitutionList::DisplayStatusOfSubstitution(const std::unordered_map<std::string, std::vector<bool>> &status,
+                                                   const OptimizerPtr &optimizer, size_t space) const {
+  std::stringstream ss;
+  ss << std::endl
+     << "Pass: " << optimizer->name() << "(" << optimizer->CurPass_.counter << ")_" << optimizer->CurPass_.name
+     << std::endl;
+  for (size_t i = 0; i < list_.size(); i++) {
+    auto name = list_[i]->name_;
+    ss << std::left << std::setw(space + 4) << name << "\t";
+    for (auto change : status.at(name + std::to_string(i))) {
+      ss << change << " ";
+    }
+    ss << std::endl;
+  }
+  MS_LOG(DEBUG) << ss.str();
+}
+
 bool SubstitutionList::ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const {
   // Add for substitution status counting
   size_t space = 0;
@@ -282,19 +299,7 @@ bool SubstitutionList::ApplySubstitutionsToIR(const OptimizerPtr &optimizer, con
 
   // Display the status of each substitution
   if (optimizer->is_on_debug_) {
-    std::stringstream ss;
-    ss << std::endl
-       << "Pass: " << optimizer->name() << "(" << optimizer->CurPass_.counter << ")_" << optimizer->CurPass_.name
-       << std::endl;
-    for (size_t i = 0; i < list_.size(); i++) {
-      auto name = list_[i]->name_;
-      ss << std::left << std::setw(space + 4) << name << "\t";
-      for (auto change : status[name + std::to_string(i)]) {
-        ss << change << " ";
-      }
-      ss << std::endl;
-    }
-    MS_LOG(DEBUG) << ss.str();
+    DisplayStatusOfSubstitution(status, optimizer, space);
   }
   return changes;
 }
diff --git a/mindspore/ccsrc/frontend/optimizer/opt.h b/mindspore/ccsrc/frontend/optimizer/opt.h
index 01f21d5df65..74711b4583a 100644
--- a/mindspore/ccsrc/frontend/optimizer/opt.h
+++ b/mindspore/ccsrc/frontend/optimizer/opt.h
@@ -20,6 +20,7 @@
 #include <memory>
 #include <string>
 #include <vector>
+#include <unordered_map>
 
 #include "ir/anf.h"
 #include "ir/func_graph.h"
@@ -74,6 +75,8 @@
   bool ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const;
   bool ApplySubstitutionToIR(const OptimizerPtr &optimizer, const AnfNodePtr &node, const SubstitutionPtr &sub) const;
   bool ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const;
+  void DisplayStatusOfSubstitution(const std::unordered_map<std::string, std::vector<bool>> &status,
+                                   const OptimizerPtr &optimizer, size_t space) const;
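Note on the refactor above: DisplayStatusOfSubstitution only pretty-prints, for each substitution in list_, the per-round change flags collected in the status map, which is keyed by substitution name plus its list index. A hypothetical Python mock of the emitted table — the substitution names, flags, and pass header below are made up for illustration, not real optimizer output:

# Hypothetical mock of the table DisplayStatusOfSubstitution writes to MS_LOG(DEBUG).
names = ["arithmetic_simplify", "inline"]         # illustrative substitution names
status = {"arithmetic_simplify0": [True, False],  # keyed by name + list index,
          "inline1": [True, True]}                # as in name + std::to_string(i)
space = max(len(n) for n in names)

lines = ["Pass: opt_a(1)_renormalize"]            # pass header is illustrative
for i, name in enumerate(names):
    flags = " ".join(str(int(c)) for c in status[name + str(i)])
    lines.append(name.ljust(space + 4) + "\t" + flags)
print("\n".join(lines))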
 
   std::vector<SubstitutionPtr> list_;
   // a flag to mark that this list of Substitution can be executed only once
diff --git a/tests/st/control/inner/test_000_single_if.py b/tests/st/control/inner/test_000_single_if.py
new file mode 100644
index 00000000000..3b09b56ceb7
--- /dev/null
+++ b/tests/st/control/inner/test_000_single_if.py
@@ -0,0 +1,60 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore import context
+from mindspore import Tensor, nn
+from mindspore.ops import composite as C
+from mindspore.common import dtype as mstype
+
+grad_all = C.GradOperation(get_all=True)
+context.set_context(device_target="Ascend")
+
+def test_single_if():
+    class SingleIfNet(nn.Cell):
+        def construct(self, x, y):
+            x += 1
+            if x < y:
+                y += x
+            else:
+                y -= x
+            y += 5
+            return y
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor(2, mstype.int32)
+    y = Tensor(5, mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    if_net = SingleIfNet()
+    net = GradNet(if_net)
+    graph_forward_res = if_net(x, y)
+    graph_backward_res = net(x, y)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    if_net = SingleIfNet()
+    net = GradNet(if_net)
+    pynative_forward_res = if_net(x, y)
+    pynative_backward_res = net(x, y)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res
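Every test below repeats the same harness as test_000: instantiate the net, run forward and grad_all backward in GRAPH_MODE, re-instantiate and repeat in PYNATIVE_MODE, then assert the two modes agree. A condensed sketch of that shared pattern — check_consistency and make_net are hypothetical names, not part of this patch:

from mindspore import context, nn
from mindspore.ops import composite as C

grad_all = C.GradOperation(get_all=True)

class GradNet(nn.Cell):
    # Same wrapper the tests use: the backward pass itself becomes a Cell,
    # so it is compiled as a graph in GRAPH_MODE.
    def __init__(self, net):
        super(GradNet, self).__init__()
        self.net = net

    def construct(self, *inputs):
        return grad_all(self.net)(*inputs)

def check_consistency(make_net, *inputs):
    # Hypothetical helper: run forward/backward in both modes and compare.
    results = []
    for mode in (context.GRAPH_MODE, context.PYNATIVE_MODE):
        context.set_context(mode=mode)
        net = make_net()  # fresh net per mode: many of these nets mutate Parameters in construct
        results.append((net(*inputs), GradNet(net)(*inputs)))
    (g_fwd, g_bwd), (p_fwd, p_bwd) = results
    assert g_fwd == p_fwd
    assert g_bwd == p_bwd

The fresh instance per mode matters: several of the nets update their Parameters inside construct, so reusing one instance would leak state from the graph-mode run into the pynative-mode run.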
diff --git a/tests/st/control/inner/test_010_if_in_if.py b/tests/st/control/inner/test_010_if_in_if.py
new file mode 100644
index 00000000000..0d178341531
--- /dev/null
+++ b/tests/st/control/inner/test_010_if_in_if.py
@@ -0,0 +1,64 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore import context
+from mindspore import Tensor, nn
+from mindspore.ops import composite as C
+from mindspore.common import dtype as mstype
+from mindspore.common.parameter import Parameter
+
+grad_all = C.GradOperation(get_all=True)
+context.set_context(device_target="Ascend")
+
+def test_if_in_if():
+    class IfInIfNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+            self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
+
+        def construct(self, x):
+            if self.param_a > self.param_b:
+                x += 10
+                if x > self.param_a:
+                    self.param_b += 1
+            x += self.param_a
+            return x
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor(2, mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    if_in_if_net = IfInIfNet()
+    net = GradNet(if_in_if_net)
+    graph_forward_res = if_in_if_net(x)
+    graph_backward_res = net(x)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    if_in_if_net = IfInIfNet()
+    net = GradNet(if_in_if_net)
+    pynative_forward_res = if_in_if_net(x)
+    pynative_backward_res = net(x)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res
diff --git a/tests/st/control/inner/test_100_if_after_if.py b/tests/st/control/inner/test_100_if_after_if.py
new file mode 100644
index 00000000000..dadd933aec9
--- /dev/null
+++ b/tests/st/control/inner/test_100_if_after_if.py
@@ -0,0 +1,65 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore import context
+from mindspore import Tensor, nn
+from mindspore.ops import composite as C
+from mindspore.common import dtype as mstype
+from mindspore.common.parameter import Parameter
+
+grad_all = C.GradOperation(get_all=True)
+context.set_context(device_target="Ascend")
+
+def test_if_after_if():
+    class IfAfterIfNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+            self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
+
+        def construct(self, x):
+            out = x + self.param_b
+            if self.param_a > self.param_b:
+                x += 5
+            self.param_b += 4
+            if x < self.param_b:
+                out += self.param_b
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor(2, mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    if_after_if_net = IfAfterIfNet()
+    net = GradNet(if_after_if_net)
+    graph_forward_res = if_after_if_net(x)
+    graph_backward_res = net(x)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    if_after_if_net = IfAfterIfNet()
+    net = GradNet(if_after_if_net)
+    pynative_forward_res = if_after_if_net(x)
+    pynative_backward_res = net(x)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res
diff --git a/tests/st/control/inner/test_110_if_after_if_in_if.py b/tests/st/control/inner/test_110_if_after_if_in_if.py
new file mode 100644
index 00000000000..7e1b6642222
--- /dev/null
+++ b/tests/st/control/inner/test_110_if_after_if_in_if.py
@@ -0,0 +1,67 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore import context
+from mindspore import Tensor, nn
+from mindspore.ops import composite as C
+from mindspore.common import dtype as mstype
+from mindspore.common.parameter import Parameter
+
+grad_all = C.GradOperation(get_all=True)
+context.set_context(device_target="Ascend")
+
+def test_if_after_if_in_if():
+    class IfAfterIfInIfNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+            self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
+
+        def construct(self, x):
+            out = x + self.param_b
+            if self.param_a > self.param_b:
+                x += 5
+                if x > self.param_a:
+                    self.param_b += 1
+            self.param_b += 3
+            if x < self.param_b:
+                out += self.param_b
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor(2, mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    if_after_if_in_if_net = IfAfterIfInIfNet()
+    net = GradNet(if_after_if_in_if_net)
+    graph_forward_res = if_after_if_in_if_net(x)
+    graph_backward_res = net(x)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    if_after_if_in_if_net = IfAfterIfInIfNet()
+    net = GradNet(if_after_if_in_if_net)
+    pynative_forward_res = if_after_if_in_if_net(x)
+    pynative_backward_res = net(x)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res
diff --git a/tests/st/control/inner/test_112_if_after_if_in_for.py b/tests/st/control/inner/test_112_if_after_if_in_for.py
new file mode 100644
index 00000000000..cddab1ce34c
--- /dev/null
+++ b/tests/st/control/inner/test_112_if_after_if_in_for.py
@@ -0,0 +1,66 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore import context
+from mindspore import Tensor, nn
+from mindspore.ops import composite as C
+from mindspore.common import dtype as mstype
+from mindspore.common.parameter import Parameter
+
+grad_all = C.GradOperation(get_all=True)
+context.set_context(device_target="Ascend")
+
+def test_if_after_if_in_for():
+    class IfAfterIfInForNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+            self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
+
+        def construct(self, x):
+            out = x + self.param_b
+            for _ in range(4):
+                if out <= 20:
+                    out += self.param_a
+            self.param_b += 3
+            if x < self.param_b:
+                out -= self.param_b
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor(2, mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    if_after_if_in_for_net = IfAfterIfInForNet()
+    net = GradNet(if_after_if_in_for_net)
+    graph_forward_res = if_after_if_in_for_net(x)
+    graph_backward_res = net(x)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    if_after_if_in_for_net = IfAfterIfInForNet()
+    net = GradNet(if_after_if_in_for_net)
+    pynative_forward_res = if_after_if_in_for_net(x)
+    pynative_backward_res = net(x)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res
diff --git a/tests/st/control/inner/test_130_if_after_for_in_if.py b/tests/st/control/inner/test_130_if_after_for_in_if.py
new file mode 100644
index 00000000000..9adb67d7ee5
--- /dev/null
+++ b/tests/st/control/inner/test_130_if_after_for_in_if.py
@@ -0,0 +1,67 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore import context
+from mindspore import Tensor, nn
+from mindspore.ops import composite as C
+from mindspore.common import dtype as mstype
+from mindspore.common.parameter import Parameter
+
+grad_all = C.GradOperation(get_all=True)
+context.set_context(device_target="Ascend")
+
+def test_if_after_for_in_if():
+    class IfAfterForInIfNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+            self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
+
+        def construct(self, x):
+            out = x + self.param_a
+            if self.param_a > self.param_b:
+                for _ in range(4):
+                    self.param_a += 1
+                    self.param_b -= 3
+            self.param_b += 15
+            if x < self.param_b:
+                out -= self.param_b
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor(2, mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    if_after_for_in_if_net = IfAfterForInIfNet()
+    net = GradNet(if_after_for_in_if_net)
+    graph_forward_res = if_after_for_in_if_net(x)
+    graph_backward_res = net(x)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    if_after_for_in_if_net = IfAfterForInIfNet()
+    net = GradNet(if_after_for_in_if_net)
+    pynative_forward_res = if_after_for_in_if_net(x)
+    pynative_backward_res = net(x)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res
diff --git a/tests/st/control/inner/test_131_if_after_for_in_while.py b/tests/st/control/inner/test_131_if_after_for_in_while.py
new file mode 100644
index 00000000000..7bb07615a8a
--- /dev/null
+++ b/tests/st/control/inner/test_131_if_after_for_in_while.py
@@ -0,0 +1,67 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore import context
+from mindspore import Tensor, nn
+from mindspore.ops import composite as C
+from mindspore.common import dtype as mstype
+from mindspore.common.parameter import Parameter
+
+grad_all = C.GradOperation(get_all=True)
+context.set_context(device_target="Ascend")
+
+def test_if_after_for_in_while():
+    class IfAfterForInWhileNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+            self.param_b = Parameter(Tensor(2, mstype.int32), name='b')
+
+        def construct(self, x):
+            out = x + self.param_a
+            while self.param_a > self.param_b:
+                self.param_b += 1
+                for _ in range(4):
+                    self.param_a += 3
+                self.param_a -= 40
+            if x > self.param_a:
+                out += self.param_a * 10
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor(2, mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    if_after_for_in_while_net = IfAfterForInWhileNet()
+    net = GradNet(if_after_for_in_while_net)
+    graph_forward_res = if_after_for_in_while_net(x)
+    graph_backward_res = net(x)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    if_after_for_in_while_net = IfAfterForInWhileNet()
+    net = GradNet(if_after_for_in_while_net)
+    pynative_forward_res = if_after_for_in_while_net(x)
+    pynative_backward_res = net(x)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res
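A subtlety in IfAfterForInWhileNet above: the while loop only terminates because of the self.param_a -= 40 at the end of its body. A plain-Python trace of the same arithmetic (Tensor/Parameter semantics aside; only the loop-count reasoning is claimed here):

# Plain-int trace of IfAfterForInWhileNet.construct's loop arithmetic.
param_a, param_b = 5, 2
iterations = 0
while param_a > param_b:
    param_b += 1            # 2 -> 3
    for _ in range(4):
        param_a += 3        # 5 -> 17
    param_a -= 40           # 17 -> -23, now below param_b
    iterations += 1
assert iterations == 1      # the -= 40 forces exit after a single pass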
diff --git a/tests/st/control/inner/test_132_if_after_for_in_for.py b/tests/st/control/inner/test_132_if_after_for_in_for.py
new file mode 100644
index 00000000000..7e178a891c7
--- /dev/null
+++ b/tests/st/control/inner/test_132_if_after_for_in_for.py
@@ -0,0 +1,67 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore import context
+from mindspore import Tensor, nn
+from mindspore.ops import composite as C
+from mindspore.common import dtype as mstype
+from mindspore.common.parameter import Parameter
+
+grad_all = C.GradOperation(get_all=True)
+context.set_context(device_target="Ascend")
+
+def test_if_after_for_in_for():
+    class IfAfterForInForNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+            self.param_b = Parameter(Tensor(2, mstype.int32), name='b')
+
+        def construct(self, x):
+            out = x + self.param_a
+            for _ in range(0, 10):
+                x *= 2
+                for _ in range(0, 5):
+                    self.param_a += 1
+                    x += self.param_b
+            if self.param_a > self.param_b:
+                out += x
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor(2, mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    if_after_for_in_for_net = IfAfterForInForNet()
+    net = GradNet(if_after_for_in_for_net)
+    graph_forward_res = if_after_for_in_for_net(x)
+    graph_backward_res = net(x)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    if_after_for_in_for_net = IfAfterForInForNet()
+    net = GradNet(if_after_for_in_for_net)
+    pynative_forward_res = if_after_for_in_for_net(x)
+    pynative_backward_res = net(x)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res
diff --git a/tests/st/control/inner/test_300_for_after_if.py b/tests/st/control/inner/test_300_for_after_if.py
new file mode 100644
index 00000000000..9001a62be76
--- /dev/null
+++ b/tests/st/control/inner/test_300_for_after_if.py
@@ -0,0 +1,66 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore import context
+from mindspore import Tensor, nn
+from mindspore.ops import composite as C
+from mindspore.common import dtype as mstype
+from mindspore.common.parameter import Parameter
+
+grad_all = C.GradOperation(get_all=True)
+context.set_context(device_target="Ascend")
+
+def test_for_after_if():
+    class ForAfterIfNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+            self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
+
+        def construct(self, x):
+            out = self.param_a
+            if self.param_a > self.param_b:
+                x += 3
+            self.param_b += 1
+            for _ in range(0, 5):
+                x += self.param_b
+            out *= x
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor(2, mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    for_after_if_net = ForAfterIfNet()
+    net = GradNet(for_after_if_net)
+    graph_forward_res = for_after_if_net(x)
+    graph_backward_res = net(x)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    for_after_if_net = ForAfterIfNet()
+    net = GradNet(for_after_if_net)
+    pynative_forward_res = for_after_if_net(x)
+    pynative_backward_res = net(x)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res
diff --git a/tests/st/control/inner/test_310_for_after_if_in_if.py b/tests/st/control/inner/test_310_for_after_if_in_if.py
new file mode 100644
index 00000000000..78d70db1138
--- /dev/null
+++ b/tests/st/control/inner/test_310_for_after_if_in_if.py
@@ -0,0 +1,69 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore import context
+from mindspore import Tensor, nn
+from mindspore.ops import composite as C
+from mindspore.common import dtype as mstype
+from mindspore.common.parameter import Parameter
+
+grad_all = C.GradOperation(get_all=True)
+context.set_context(device_target="Ascend")
+
+def test_for_after_if_in_if():
+    class ForAfterIfInIfNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+            self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
+
+        def construct(self, x):
+            out = self.param_a
+            if self.param_a > self.param_b:
+                x += 3
+                if x > self.param_a:
+                    self.param_b += 4
+                x += self.param_a
+            self.param_b += 2
+            for _ in range(0, 5):
+                x += self.param_b
+            out *= x
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor(5, mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    for_after_if_in_if_net = ForAfterIfInIfNet()
+    net = GradNet(for_after_if_in_if_net)
+    graph_forward_res = for_after_if_in_if_net(x)
+    graph_backward_res = net(x)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    for_after_if_in_if_net = ForAfterIfInIfNet()
+    net = GradNet(for_after_if_in_if_net)
+    pynative_forward_res = for_after_if_in_if_net(x)
+    pynative_backward_res = net(x)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res
diff --git a/tests/st/control/inner/test_330_for_after_for_in_if.py b/tests/st/control/inner/test_330_for_after_for_in_if.py
new file mode 100644
index 00000000000..d3246758f25
--- /dev/null
+++ b/tests/st/control/inner/test_330_for_after_for_in_if.py
@@ -0,0 +1,68 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from mindspore import context
+from mindspore import Tensor, nn
+from mindspore.ops import composite as C
+from mindspore.common import dtype as mstype
+from mindspore.common.parameter import Parameter
+
+grad_all = C.GradOperation(get_all=True)
+context.set_context(device_target="Ascend")
+
+def test_for_after_for_in_if():
+    class ForAfterForInIfNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+            self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
+
+        def construct(self, x):
+            out = self.param_a
+            if self.param_a > self.param_b:
+                for _ in range(0, 4):
+                    self.param_a += 1
+                    self.param_b -= 3
+            self.param_b += 10
+            for _ in range(0, 5):
+                x += self.param_b
+            out *= x
+            return out
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor(5, mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    for_after_for_in_if_net = ForAfterForInIfNet()
+    net = GradNet(for_after_for_in_if_net)
+    graph_forward_res = for_after_for_in_if_net(x)
+    graph_backward_res = net(x)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    for_after_for_in_if_net = ForAfterForInIfNet()
+    net = GradNet(for_after_for_in_if_net)
+    pynative_forward_res = for_after_for_in_if_net(x)
+    pynative_backward_res = net(x)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res