From a903cad5b379bdef480a6e6191b2e6d2e6793169 Mon Sep 17 00:00:00 2001 From: Margaret_wangrui Date: Fri, 3 Sep 2021 10:12:43 +0800 Subject: [PATCH] Revert "Dealing with the random accuracy problem: parameter and load are equivalent." This reverts commit 1a2f7e26639f701bb0c8405b52d6c9e567f340df. --- .../optimizer/auto_monad_eliminate.cc | 62 +++++++------------ 1 file changed, 24 insertions(+), 38 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc index 8f7ac74fe48..8137a6207d5 100644 --- a/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc +++ b/mindspore/ccsrc/frontend/optimizer/auto_monad_eliminate.cc @@ -31,13 +31,12 @@ std::unordered_set GetAllSubGraphs(const std::unordered_setcast(); MS_EXCEPTION_IF_NULL(cnode); - auto fg_idx = IsValueNode(cnode->input(0)) ? 0 : 1; + auto fg_idx = IsPrimitiveCNode(cnode, prim::kPrimCall) ? 0 : 1; auto fg_value_node = cnode->input(fg_idx)->cast(); MS_EXCEPTION_IF_NULL(fg_value_node); auto value = fg_value_node->value(); MS_EXCEPTION_IF_NULL(value); auto caller_fg = value->cast(); - graphs.insert(caller_fg); auto sub_graphs = caller_fg->func_graphs_used_total(); for (auto sub_graph : sub_graphs) { graphs.insert(sub_graph); @@ -47,28 +46,18 @@ std::unordered_set GetAllSubGraphs(const std::unordered_set &call_partial, - const AnfNodePtr ¶m) { + const AnfNodePtr &load_param) { if (call_partial.empty()) { return false; } auto manager = fg->manager(); - auto param_users = manager->node_users()[param]; - std::unordered_set all_users; - for (auto param_user : param_users) { - all_users.insert(param_user.first); - if (IsPrimitiveCNode(param_user.first, prim::kPrimLoad)) { - auto load_users = manager->node_users()[param_user.first]; - for (auto load_user : load_users) { - all_users.insert(load_user.first); - } - } - } + auto load_param_users = manager->node_users()[load_param]; std::unordered_set sub_graphs = GetAllSubGraphs(call_partial); - for (auto user : all_users) { - if (!user->isa()) { + for (auto user : load_param_users) { + if (!user.first->isa()) { continue; } - auto node = user->cast(); + auto node = user.first->cast(); auto user_graph = node->func_graph(); // Check if user graph is in sub graphs. bool exist_user_graph = std::any_of(sub_graphs.begin(), sub_graphs.end(), @@ -87,7 +76,7 @@ std::vector> GenerateLoadGroups(const FuncGraphPtr &fg, cons // Record the param and the user set of param in toposet nodes std::unordered_map> unload_users_record; std::unordered_set call_partial_nodes; - bool has_umonad_call_node_users = false; + std::unordered_map has_umonad_call_node_users; for (size_t i = 0; i < toposet.size(); i++) { auto &node = toposet[i]; auto cnode = node->cast(); @@ -95,23 +84,20 @@ std::vector> GenerateLoadGroups(const FuncGraphPtr &fg, cons continue; } // Record param user in toposort nodes. - if (!IsPrimitiveCNode(cnode, prim::kPrimLoad)) { - for (const auto &input : cnode->inputs()) { - AnfNodePtr cur_param = nullptr; - if (input->isa()) { - cur_param = input; - } else if (IsPrimitiveCNode(input, prim::kPrimLoad)) { - cur_param = input->cast()->input(1); - } else if (IsPrimitiveCNode(input, prim::kPrimDepend) && input->cast()->input(1)->isa()) { - cur_param = input->cast()->input(1); - } - if (cur_param != nullptr) { - unload_users_record[cur_param].insert(cnode); - } + for (const auto &input : cnode->inputs()) { + AnfNodePtr cur_param = nullptr; + if (input->isa()) { + cur_param = input; + } else if (IsPrimitiveCNode(input, prim::kPrimLoad)) { + cur_param = input->cast()->input(1); + } else if (IsPrimitiveCNode(input, prim::kPrimDepend) && input->cast()->input(1)->isa()) { + cur_param = input->cast()->input(1); + } + if (cur_param != nullptr) { + unload_users_record[cur_param].insert(cnode); } } - if (IsPrimitiveCNode(cnode, prim::kPrimCall) || IsValueNode(cnode->input(0)) || - IsPrimitiveCNode(cnode, prim::kPrimPartial)) { + if (IsPrimitiveCNode(cnode, prim::kPrimCall) || IsPrimitiveCNode(cnode, prim::kPrimPartial)) { call_partial_nodes.insert(cnode); } @@ -129,11 +115,11 @@ std::vector> GenerateLoadGroups(const FuncGraphPtr &fg, cons load_groups.push_back({i}); // If had not user in toposort, should check if has call or partial user // If already has call node user, do not need check again. - if (!unload_users_record[load_param].empty()) { + if (!unload_users_record[load_param].empty() || has_umonad_call_node_users[load_param] == true) { continue; } - has_umonad_call_node_users = HasUMonadCallNodeUser(fg, call_partial_nodes, load_param); - if (!has_umonad_call_node_users) { + has_umonad_call_node_users[load_param] = HasUMonadCallNodeUser(fg, call_partial_nodes, load_param); + if (has_umonad_call_node_users[load_param] == false) { need_replace_loads->emplace_back(cnode); } } else { @@ -165,8 +151,8 @@ std::vector> SplitGroup(const std::vector &topos return false; } // if Call/Switch/SwitchLayer, do not replace load. - if (IsPrimitiveCNode(node, prim::kPrimCall) || IsValueNode(node->cast()->input(0)) || - IsPrimitiveCNode(node, prim::kPrimSwitch) || IsPrimitiveCNode(node, prim::kPrimSwitchLayer)) { + if (IsPrimitiveCNode(node, prim::kPrimCall) || IsPrimitiveCNode(node, prim::kPrimSwitch) || + IsPrimitiveCNode(node, prim::kPrimSwitchLayer)) { return true; } auto cnode = node->cast();