From 16466f3453adba57a59d07a64c7d7d75406d83b2 Mon Sep 17 00:00:00 2001
From: Margaret_wangrui
Date: Mon, 26 Jul 2021 16:23:56 +0800
Subject: [PATCH] Fix execution order problem: users of MakeTuple(Load, ...) are not attached to UpdateState

---
 .../jit/static_analysis/order_enforce.cc      | 106 +++++++++++++++++-
 ..._adam.py => test_auto_monad_expression.py} |  40 +++++++
 2 files changed, 140 insertions(+), 6 deletions(-)
 rename tests/st/auto_monad/{test_auto_monad_addn_adam.py => test_auto_monad_expression.py} (71%)

diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc
index bd757f1b3e5..74a71f7ee18 100644
--- a/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc
@@ -37,7 +37,13 @@ class OrderEnforcer {
   void Run() {
     auto nodes = MakeTopoSortMap();
     for (auto &node : nodes) {
-      HandleNode(node);
+      if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
+        HandleUpdateState(node);
+      } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
+        // Users of MakeTuple(Load, ...) sometimes do not attach an UpdateState,
+        // so handle them specially to keep the execution order of MakeTuple users.
+        HandleMakeTupleUsers(node);
+      }
     }
   }
 
@@ -50,11 +56,7 @@
     return nodes;
   }
 
-  void HandleNode(const AnfNodePtr &node) {
-    if (!IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
-      // Skip nodes other than UpdateState.
-      return;
-    }
+  void HandleUpdateState(const AnfNodePtr &node) {
     auto update_state = node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(update_state);
     const size_t update_state_inputs_size = 3;
@@ -74,6 +76,98 @@
     }
   }
 
+  bool CheckMakeTupleHaveLoad(const CNodePtr &cnode) {
+    auto inputs = cnode->inputs();
+    for (size_t index = 1; index < inputs.size(); index++) {
+      auto input = cnode->input(index);
+      if (IsPrimitiveCNode(input, prim::kPrimLoad)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  std::vector<AnfNodePtr> FindUpdateStateUsers(const CNodePtr &cnode) {
+    auto &node_users = manager_->node_users();
+    auto iter = node_users.find(cnode);
+    if (iter == node_users.end()) {
+      return {};
+    }
+    std::vector<AnfNodePtr> update_states;
+    auto &users = iter->second;
+    for (auto &user : users) {
+      auto &user_node = user.first;
+      if (IsPrimitiveCNode(user_node, prim::kPrimUpdateState)) {
+        update_states.emplace_back(user_node);
+      } else if (IsPrimitiveCNode(user_node, prim::kPrimMakeTuple)) {
+        auto make_tuple_users = FindUpdateStateUsers(user_node->cast<CNodePtr>());
+        for (auto make_tuple_user : make_tuple_users) {
+          if (IsPrimitiveCNode(make_tuple_user, prim::kPrimUpdateState)) {
+            update_states.emplace_back(make_tuple_user);
+          }
+        }
+      }
+    }
+    return update_states;
+  }
+
+  AnfNodePtr FindLastUpdateState(const CNodePtr &cnode) {
+    auto inputs = cnode->inputs();
+    std::vector<AnfNodePtr> all_update_states;
+    for (size_t index = 1; index < inputs.size(); index++) {
+      auto input = cnode->input(index);
+      if (IsPrimitiveCNode(input, prim::kPrimLoad)) {
+        std::vector<AnfNodePtr> update_states = FindUpdateStateUsers(input->cast<CNodePtr>());
+        std::copy(update_states.begin(), update_states.end(), std::back_inserter(all_update_states));
+      }
+    }
+    AnfNodePtr last_update_state = nullptr;
+    if (all_update_states.empty()) {
+      return last_update_state;
+    }
+    if (all_update_states.size() == 1) {
+      return all_update_states[0];
+    }
+    last_update_state = all_update_states[0];
+    for (size_t i = 1; i < all_update_states.size(); i++) {
+      auto cur_update_state = all_update_states[i];
+      if (topo_sort_map_[cur_update_state] > topo_sort_map_[last_update_state]) {
+        last_update_state = cur_update_state;
+      }
+    }
+    return last_update_state;
+  }
+
+  // Convert:
+  // load1 = Load(para1, u1)
+  // load2 = Load(para2, u2)
+  // maketuple1 = MakeTuple(inputs, load1, load2)
+  // addn = AddN(maketuple1) or other-op
+  // maketuple2 = MakeTuple(load1, load2)
+  // u3 = UpdateState(u', maketuple2)
+  // assign = Assign(para2, inputs, u3)
+  // To:
+  // load1 = Load(para1, u1)
+  // load2 = Load(para2, u2)
+  // maketuple1 = MakeTuple(inputs, load1, load2)
+  // addn = AddN(maketuple1) or other-op
+  // maketuple2 = MakeTuple(load1, load2)
+  // u3 = UpdateState(u', maketuple2, addn)  # need to put addn or other-op into u3 inputs
+  // assign = Assign(para2, inputs, u3)
+  void HandleMakeTupleUsers(const AnfNodePtr &node) {
+    auto maketuple = node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(maketuple);
+    if (CheckMakeTupleHaveLoad(maketuple)) {
+      auto update_state = FindLastUpdateState(maketuple);
+      if (update_state != nullptr) {
+        std::unordered_set<AnfNodePtr> maketuple_users = GetSpecialOperatorRealUsers(maketuple);
+        auto update_state_cnode = update_state->cast<CNodePtr>();
+        MS_EXCEPTION_IF_NULL(update_state_cnode);
+        AddInputEdges(update_state_cnode, maketuple_users);
+      }
+    }
+  }
+
   bool IsRef(const AnfNodePtr &node) {
     auto &abs = node->abstract();
     return abs != nullptr && abs->isa<abstract::AbstractRef>();
diff --git a/tests/st/auto_monad/test_auto_monad_addn_adam.py b/tests/st/auto_monad/test_auto_monad_expression.py
similarity index 71%
rename from tests/st/auto_monad/test_auto_monad_addn_adam.py
rename to tests/st/auto_monad/test_auto_monad_expression.py
index 6a2013523a3..b714ca1c5fd 100644
--- a/tests/st/auto_monad/test_auto_monad_addn_adam.py
+++ b/tests/st/auto_monad/test_auto_monad_expression.py
@@ -81,3 +81,43 @@ def test_auto_monad_addn_adam():
     allclose_nparray(new_var_pyn.asnumpy(), new_var.asnumpy(), 0.001, 0.001)
     allclose_nparray(new_m_pyn.asnumpy(), new_m.asnumpy(), 0.001, 0.001)
     allclose_nparray(new_v_pyn.asnumpy(), new_v.asnumpy(), 0.001, 0.001)
+
+
+class AutoMonadTwoAssignTwoAddnDependencyNet(Cell):
+    def __init__(self):
+        super().__init__()
+        self.parameter1 = ms.Parameter(Tensor([1.0], ms.float32), name="parameter1")
+        self.parameter2 = ms.Parameter(Tensor([3.0], ms.float32), name="parameter2")
+        self.assign = P.Assign()
+        self.addN = P.AddN()
+
+    def construct(self, inputs):
+        self.assign(self.parameter1, inputs)
+        out = self.addN((inputs, self.parameter1, self.parameter2))
+        self.assign(self.parameter2, inputs)
+        out = self.addN((out, self.parameter1, self.parameter2))
+        return out
+
+
+class AutoMonadTwoAssignTwoAddnDependencyBenchmarkNet(Cell):
+    def __init__(self):
+        super().__init__()
+        self.parameter2 = ms.Parameter(Tensor([3.0], ms.float32), name="parameter2")
+        self.addN = P.AddN()
+
+    def construct(self, inputs):
+        out = self.addN((inputs, inputs, self.parameter2))
+        out = self.addN((out, inputs, inputs))
+        return out
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_auto_monad_read_dependency_two_assign_two_addn():
+    net = AutoMonadTwoAssignTwoAddnDependencyNet()
+    benchmarknet = AutoMonadTwoAssignTwoAddnDependencyBenchmarkNet()
+    out1 = net(Tensor([9.0], ms.float32))
+    out2 = benchmarknet(Tensor([9.0], ms.float32))
+    allclose_nparray(out1.asnumpy(), out2.asnumpy(), 0.001, 0.001)
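
For reference, the ordering problem this patch guards against can be exercised with a much smaller network than the added test: an AddN that reads a parameter through Load, followed by an Assign that overwrites the same parameter. The sketch below is not part of the patch; the class name, parameter name, and the graph-mode/device setup are illustrative assumptions, and a build containing this fix is assumed.

import mindspore as ms
from mindspore import Tensor, context
from mindspore.nn import Cell
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE)


class LoadThenAssignNet(Cell):
    """Hypothetical minimal net: AddN reads para (via Load) before Assign rewrites it."""
    def __init__(self):
        super().__init__()
        self.para = ms.Parameter(Tensor([1.0], ms.float32), name="para")
        self.assign = P.Assign()
        self.addn = P.AddN()

    def construct(self, x):
        out = self.addn((x, self.para))  # reads para through Load
        self.assign(self.para, x)        # writes para afterwards
        return out


net = LoadThenAssignNet()
print(net(Tensor([9.0], ms.float32)))  # expected [10.0] when AddN executes before the Assign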