From 2a85af5d8328e652e79a7f70cf370a8997ca8076 Mon Sep 17 00:00:00 2001
From: huangbingjian
Date: Tue, 27 Apr 2021 17:24:55 +0800
Subject: [PATCH] 1. if Switch/SwitchLayer, do not replace Load or remove UpdateState; 2. add control flow testcases; 3. fix codedex problem

---
 .../ccsrc/frontend/optimizer/ad/dfunctor.cc   |  2 +-
 .../optimizer/auto_monad_eliminate.cc         |  5 ++
 .../optimizer/irpass/updatestate_eliminate.cc |  8 +++
 mindspore/ccsrc/frontend/optimizer/opt.cc     |  2 +-
 mindspore/core/ir/tensor.cc                   |  2 +-
 tests/st/control/inner/test_030_for_in_if.py  | 49 +++++++++++++++++++
 .../st/control/inner/test_100_if_after_if.py  | 25 ++++++++++
 7 files changed, 90 insertions(+), 3 deletions(-)

diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc
index 27e06fcb293..a1e7a0d29ff 100644
--- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc
+++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc
@@ -880,7 +880,7 @@ static std::vector> FindPrimalJPair(const FuncGrap
     } else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) {
       // To find J user.
       auto j_user = GetJUser(node_user_map, cnode, index);
-      primal_j_pair.emplace_back(std::pair(nullptr, j_user));
+      (void)primal_j_pair.emplace_back(std::pair(nullptr, j_user));
     }
   }
diff --git a/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc
index 2424df355e1..911fc092d92 100644
--- a/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc
+++ b/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc
@@ -85,6 +85,11 @@ std::vector> SplitGroup(const std::vector &topos
     if (IsPrimitiveCNode(node, prim::kPrimLoad)) {
       return false;
     }
+    // if Call/Switch/SwitchLayer, do not replace load.
+    if (IsPrimitiveCNode(node, prim::kPrimCall) || IsPrimitiveCNode(node, prim::kPrimSwitch) ||
+        IsPrimitiveCNode(node, prim::kPrimSwitchLayer)) {
+      return true;
+    }
     auto cnode = node->cast<CNodePtr>();
     auto &inputs = cnode->inputs();
     return std::any_of(inputs.begin(), inputs.end(),
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc
index d5e52f56917..b0a62c0a6f1 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc
@@ -28,6 +28,7 @@ namespace {
 // data = Load(input, attach)
 // data = Depend(input, attach)
 // monad = UpdateState(input, attach)
+constexpr size_t kFirstInputIndex = 0;
 constexpr size_t kInputIndex = 1;
 constexpr size_t kAttachIndex = 2;
 constexpr size_t kMakeTupleSize = 3;
@@ -120,6 +121,13 @@ AnfNodePtr EliminateUpdateStateForPureNode(const CNodePtr &update_state, const A
       return nullptr;
     }
   }
+  // Skip Call/Switch/SwitchLayer.
+  auto first_input_node = cnode->input(kFirstInputIndex);
+  if (IsPrimitiveCNode(first_input_node, prim::kPrimCall) || IsPrimitiveCNode(first_input_node, prim::kPrimSwitch) ||
+      IsPrimitiveCNode(first_input_node, prim::kPrimSwitchLayer)) {
+    return nullptr;
+  }
+
   // Remove UpdateState by replace it with its input monad.
   return update_state->input(kInputIndex);
 }
diff --git a/mindspore/ccsrc/frontend/optimizer/opt.cc b/mindspore/ccsrc/frontend/optimizer/opt.cc
index ccc6ca4b6bd..3b2c53bddcb 100644
--- a/mindspore/ccsrc/frontend/optimizer/opt.cc
+++ b/mindspore/ccsrc/frontend/optimizer/opt.cc
@@ -246,7 +246,7 @@ void SubstitutionList::DisplayStatusOfSubstitution(const std::unordered_mapname_;
-    ss << std::left << std::setw(space + 4) << name << "\t";
+    ss << std::left << std::setw(SizeToInt(space) + 4) << name << "\t";
     for (auto change : status.at(name + std::to_string(i))) {
       ss << change << " ";
     }
diff --git a/mindspore/core/ir/tensor.cc b/mindspore/core/ir/tensor.cc
index d5f934ad5ec..6d0249e6109 100644
--- a/mindspore/core/ir/tensor.cc
+++ b/mindspore/core/ir/tensor.cc
@@ -393,7 +393,7 @@ class TensorDataImpl : public TensorData {
       pos++;
     }
     size_t len = pos - index;
-    std::string space(max_width - len, ' ');
+    std::string space(max_width - SizeToInt(len), ' ');
     str = str.replace(index, len, space);
     index = str.find('#', index);
   }
diff --git a/tests/st/control/inner/test_030_for_in_if.py b/tests/st/control/inner/test_030_for_in_if.py
index 6b01497253a..722ca4714f5 100644
--- a/tests/st/control/inner/test_030_for_in_if.py
+++ b/tests/st/control/inner/test_030_for_in_if.py
@@ -170,3 +170,52 @@ def test_for_in_if_03():
 
     assert graph_forward_res == pynative_forward_res
     assert graph_backward_res == pynative_backward_res
+
+
+def test_for_in_if_04():
+    class ForInIfNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+            self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
+
+        def construct(self, x):
+            out = self.param_a
+            x = self.func(x)
+            out *= x
+            return out
+
+        def func(self, x):
+            if self.param_a > self.param_b:
+                for _ in range(0, 4):
+                    self.param_a += 1
+                    self.param_b -= 3
+            self.param_b += 10
+            return x
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    x = Tensor(5, mstype.int32)
+
+    # graph mode
+    context.set_context(mode=context.GRAPH_MODE)
+    for_in_if_net = ForInIfNet()
+    net = GradNet(for_in_if_net)
+    graph_forward_res = for_in_if_net(x)
+    graph_backward_res = net(x)
+
+    # pynative mode
+    context.set_context(mode=context.PYNATIVE_MODE)
+    for_in_if_net = ForInIfNet()
+    net = GradNet(for_in_if_net)
+    pynative_forward_res = for_in_if_net(x)
+    pynative_backward_res = net(x)
+
+    assert graph_forward_res == pynative_forward_res
+    assert graph_backward_res == pynative_backward_res
diff --git a/tests/st/control/inner/test_100_if_after_if.py b/tests/st/control/inner/test_100_if_after_if.py
index 6b03551cca4..f68af8cd58e 100644
--- a/tests/st/control/inner/test_100_if_after_if.py
+++ b/tests/st/control/inner/test_100_if_after_if.py
@@ -74,6 +74,25 @@ class IfAfterIfNet2(nn.Cell):
         return y
 
+
+class IfAfterIfNet3(nn.Cell):
+    def __init__(self):
+        super().__init__()
+        self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
+        self.param_b = Parameter(Tensor(4, mstype.int32), name='b')
+
+    def construct(self, x, y):
+        out = x * y + self.func(self.param_b)
+        if self.param_a > self.param_b:
+            out += 5
+        return out
+
+    def func(self, x):
+        if self.param_a > self.param_b:
+            x += 5
+        self.param_b += 4
+        return x
+
+
 class GradNet(nn.Cell):
     def __init__(self, net):
         super(GradNet, self).__init__()
         self.net = net
@@ -118,3 +137,9 @@ def test_if_after_if_02():
     x = Tensor(2, mstype.int32)
     y = Tensor(5, mstype.int32)
     control_flow_if_after_if(IfAfterIfNet2, x, y)
+
+
+def test_if_after_if_03():
+    x = Tensor(2, mstype.int32)
+    y = Tensor(5, mstype.int32)
+    control_flow_if_after_if(IfAfterIfNet3, x, y)