From ca91c1c484221c3b14ee4189bc3ce8855805bd3d Mon Sep 17 00:00:00 2001 From: chenfei Date: Thu, 23 Dec 2021 11:38:04 +0800 Subject: [PATCH] disbale incorporate getitem when ENABLE_CLOSURE!=1 add value tuple visit debug log --- .../frontend/optimizer/dead_node_eliminate.cc | 231 ++++++++++++++---- .../frontend/optimizer/dead_node_eliminate.h | 2 +- .../optimizer/irpass/env_item_eliminate.h | 12 + .../optimizer/irpass/incorporate_getitem.h | 4 + 4 files changed, 200 insertions(+), 49 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/dead_node_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/dead_node_eliminate.cc index 38ce0b3829c..cb072841859 100644 --- a/mindspore/ccsrc/frontend/optimizer/dead_node_eliminate.cc +++ b/mindspore/ccsrc/frontend/optimizer/dead_node_eliminate.cc @@ -53,6 +53,12 @@ class VisitContext { return true; } + bool IndexVisited(int64_t index) { + return std::any_of(index_stacks_.begin(), index_stacks_.end(), [&index](const std::vector &index_stack) { + return !index_stack.empty() && index_stack.back() == index; + }); + } + std::set> index_stacks_; }; using VisitContextPtr = std::shared_ptr; @@ -66,17 +72,26 @@ class ContextManager { bool AddContext(const AnfNodePtr &node, const std::vector &index_stack) { auto it = contexts_.find(node); if (it == contexts_.end()) { + MS_LOG(DEBUG) << "Add node: " << node->DebugString(); contexts_[node] = std::make_shared(index_stack); return true; } return it->second->Add(index_stack); } + + bool IndexVisited(const CNodePtr &node, int64_t index) { + auto it = contexts_.find(node); + if (it == contexts_.end()) { + return false; + } + return it->second->IndexVisited(index); + } }; void VisitNode(const AnfNodePtr &node, const FuncGraphAnalyzer &analyzer, std::vector index_stack, size_t seen, ContextManager *context_manager) { if (IS_OUTPUT_ON(DEBUG)) { - MS_LOG(DEBUG) << "Visit node:" << node->DebugString(); + MS_LOG(WARNING) << "Visit node:" << node->DebugString(); for (size_t i = 0; i < index_stack.size(); i++) { MS_LOG(DEBUG) << "index_stack[" << i << "]: " << index_stack[i]; } @@ -97,7 +112,9 @@ void VisitNode(const AnfNodePtr &node, const FuncGraphAnalyzer &analyzer, std::v index_stack.push_back(output_idx); auto real_input = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem); VisitNode(real_input, analyzer, index_stack, seen, context_manager); - } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { + return; + } + if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { // If make_tuple in make_tuple, visit may start with inner tuple_getitem. if (index_stack.empty()) { return; @@ -106,13 +123,17 @@ void VisitNode(const AnfNodePtr &node, const FuncGraphAnalyzer &analyzer, std::v auto output_idx = index_stack.back(); index_stack.pop_back(); VisitNode(make_tuple->input(1 + output_idx), analyzer, index_stack, seen, context_manager); - } else if (IsFuncGraphCallNode(node)) { + return; + } + if (IsFuncGraphCallNode(node)) { const auto &caller_func_graphs = analyzer.GetCallerFuncGraphs(node); for (const auto &fg : caller_func_graphs) { auto new_index_stack = std::vector(index_stack); VisitNode(fg->output(), analyzer, new_index_stack, seen, context_manager); } - } else if (node->isa()) { + return; + } + if (node->isa()) { const auto &func_callers = analyzer.GetFuncGraphCallers(node->func_graph()); for (auto &caller : func_callers) { const auto &args = analyzer.GetArg(node, caller); @@ -121,25 +142,13 @@ void VisitNode(const AnfNodePtr &node, const FuncGraphAnalyzer &analyzer, std::v VisitNode(arg, analyzer, new_index_stack, seen, context_manager); } } - } else { - if (!index_stack.empty()) { - // TupleGetItem's input may not be a MakeTuple but a ValueTuple. - MS_LOG(DEBUG) << "Reach the end node: " << node->DebugString() << ", but index stack is not empty."; - } return; } -} - -void EraseMakeTupleInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { - // Don't eliminate the parameter of graph - if (node->isa()) { - MS_LOG(WARNING) << "Parameter:" << node->DebugString() << " is dead node and can't be erased."; + if (node->isa()) { + // TupleGetItem's input may not be a MakeTuple but a ValueTuple. return; } - auto new_tensor = NewValueNode(MakeValue(0)); - auto abs = std::make_shared(std::make_shared(0)); - new_tensor->set_abstract(abs); - func_graph->manager()->Replace(node, new_tensor); + MS_LOG(DEBUG) << "Reach the end node: " << node->DebugString() << ", but index stack is not empty."; } std::vector GenerateOutputTempGetItems(const FuncGraphPtr &func_graph) { @@ -180,17 +189,111 @@ bool IsScalarValueNode(const AnfNodePtr &node) { return node->abstract()->isa(); } -bool EliminateDeadNode(const FuncGraphPtr &func_graph) { - std::vector tuple_getitem_nodes; - std::vector make_tuple_nodes; - const auto &all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude); - for (const auto &node : all_nodes) { - if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { - tuple_getitem_nodes.emplace_back(node); - } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { - make_tuple_nodes.emplace_back(node); +bool EraseMakeTupleInput(const FuncGraphPtr &func_graph, const CNodePtr &make_tuple, size_t input_idx) { + // Scalar(int) no need convert to Scalar(0), and Scalar(0) cannot be erased once again. + auto node = make_tuple->input(input_idx); + if (IsScalarValueNode(node)) { + return false; + } + MS_LOG(WARNING) << "Erase dead node: " << node->DebugString() << ", user make_tuple: " << make_tuple->DebugString(); + auto new_tensor = NewValueNode(MakeValue(0)); + auto abs = std::make_shared(std::make_shared(0)); + new_tensor->set_abstract(abs); + // Can't use `Replace`, must user `SetEdge`. + func_graph->manager()->SetEdge(make_tuple, input_idx, new_tensor); + return true; +} + +void VisitValue(const ValuePtr &value, std::vector indexes, + HashMap> *visited_values) { + MS_EXCEPTION_IF_NULL(value); + MS_LOG(DEBUG) << "Visit value:" << value->ToString(); + if (indexes.empty()) { + MS_LOG(DEBUG) << "Indexes empty"; + return; + } + const auto visit_index = indexes.back(); + (*visited_values)[value].insert(visit_index); + auto value_tuple = value->cast(); + MS_EXCEPTION_IF_NULL(value_tuple); + if (LongToSize(visit_index) >= value_tuple->size()) { + MS_LOG(EXCEPTION) << "Index: " << visit_index << " out of range: " << value_tuple->size(); + } + indexes.pop_back(); + MS_LOG(DEBUG) << "Visit index: " << visit_index; + VisitValue(value_tuple->value()[LongToSize(visit_index)], indexes, visited_values); +} + +std::pair EraseValue(const ValuePtr &value, const abstract::AbstractBasePtr &abs, + const HashMap> &visited_values, + bool need_erase) { + if (need_erase) { + auto new_value = MakeValue(0); + auto new_abs = std::make_shared(std::make_shared(0)); + new_abs->set_value(new_value); + MS_LOG(WARNING) << "Erase value:" << value->ToString(); + return {new_value, new_abs}; + } + auto it = visited_values.find(value); + if (it == visited_values.end()) { + return {value, abs}; + } + const auto &all_visit_index = it->second; + + auto value_tuple = value->cast(); + MS_EXCEPTION_IF_NULL(value_tuple); + auto abs_tuple = abs->cast(); + MS_EXCEPTION_IF_NULL(abs_tuple); + auto new_elements = std::vector(value_tuple->value()); + auto new_abstracts = std::vector(abs_tuple->elements()); + if (new_elements.size() != new_abstracts.size()) { + MS_LOG(EXCEPTION) << "Value size: " << new_elements.size() + << " is not equal to abstract size: " << new_abstracts.size(); + } + + bool change = false; + for (size_t i = 0; i < value_tuple->value().size(); i++) { + auto value_i = new_elements[i]; + auto abs_i = new_abstracts[i]; + // Avoid repeatedly erase. + MS_LOG(WARNING) << "value_i:[" << i << "]: " << value_i->ToString(); + if (value_i->isa()) { + continue; + } + bool need_erase_i = all_visit_index.find(SizeToLong(i)) == all_visit_index.end(); + auto [ret_value, ret_abs] = EraseValue(value_i, abs_i, visited_values, need_erase_i); + if (ret_value != value_i) { + new_elements[i] = ret_value; + new_abstracts[i] = ret_abs; + change = true; } } + if (change) { + value_tuple = std::make_shared(new_elements); + abs_tuple = std::make_shared(new_abstracts); + abs_tuple->set_value(value_tuple); + } + return {value_tuple, abs_tuple}; +} + +bool EraseValueTuple(const AnfNodePtr &node, const std::set> &contexts) { + HashMap> visited_values; + const auto value = GetValueNode(node); + for (const auto &context : contexts) { + VisitValue(value, context, &visited_values); + } + // Erase the unvisited values. + auto [new_value, new_abs] = EraseValue(value, node->abstract(), visited_values, false); + if (new_value != value) { + node->cast()->set_value(new_value); + node->set_abstract(new_abs); + MS_LOG(DEBUG) << "Set new value of node: " << node->DebugString(); + return true; + } + return false; +} + +bool EliminateDeadNode(const FuncGraphPtr &func_graph) { // Travers all tuple getitem nodes to visit. FuncGraphAnalyzer analyzer(func_graph); analyzer.Run(); @@ -198,31 +301,63 @@ bool EliminateDeadNode(const FuncGraphPtr &func_graph) { if (!analyzer.HasIncorporateCall()) { return false; } + auto seen = NewSeenGeneration(); std::vector index_stack; - ContextManager context_manager; - // Visit from all tuple_getitem. - for (const auto &tuple_getitem : tuple_getitem_nodes) { - VisitNode(tuple_getitem, analyzer, index_stack, seen, &context_manager); - } - // Visit from root graph output. - const auto &output_getitems = GenerateOutputTempGetItems(func_graph); - for (const auto &tuple_getitem : output_getitems) { - VisitNode(tuple_getitem, analyzer, index_stack, seen, &context_manager); - } - // Check all make tuple's input bool change = false; - for (const auto &make_tuple : make_tuple_nodes) { - auto make_tuple_cnode = make_tuple->cast(); - for (size_t i = 1; i < make_tuple_cnode->size(); i++) { - const auto &input = make_tuple_cnode->input(i); - // If make_tuple was not visited ,it may be a make tuple of swith_layer or addn and some other ops. - if (input->seen_ != seen && make_tuple_cnode->seen_ == seen && !IsScalarValueNode(input)) { - MS_LOG(INFO) << "Find dead node: " << input->DebugString(); - change = true; - EraseMakeTupleInput(func_graph, input); + bool cycle_change = true; + while (cycle_change) { + ContextManager context_manager; + std::vector tuple_getitem_nodes; + std::vector make_tuple_nodes; + std::vector value_tuples; + const auto &all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude); + for (const auto &node : all_nodes) { + if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { + tuple_getitem_nodes.emplace_back(node); + } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { + make_tuple_nodes.emplace_back(node); + } else if (IsValueNode(node)) { + value_tuples.emplace_back(node); } } + // Visit from all tuple_getitem. + for (const auto &tuple_getitem : tuple_getitem_nodes) { + VisitNode(tuple_getitem, analyzer, index_stack, seen, &context_manager); + } + // Visit from root graph output. + const auto &output_getitems = GenerateOutputTempGetItems(func_graph); + for (const auto &tuple_getitem : output_getitems) { + VisitNode(tuple_getitem, analyzer, index_stack, seen, &context_manager); + } + // Check all make tuple's input + cycle_change = false; + for (const auto &make_tuple : make_tuple_nodes) { + MS_LOG(WARNING) << "Check make_tuple:" << make_tuple->DebugString(); + auto make_tuple_cnode = make_tuple->cast(); + for (size_t i = 1; i < make_tuple_cnode->size(); i++) { + // If make_tuple was not visited ,it may be a make tuple of swith_layer or addn and some other ops. + auto input_edge_visited = context_manager.IndexVisited(make_tuple_cnode, i - 1); + // Can use `context_manager.contexts_.find(make_tuple_cnode) != context_manager.contexts_.end()`. + auto make_tuple_visited = make_tuple_cnode->seen_ == seen; + MS_LOG(WARNING) << "Check [" << i - 1 << "]:" + << ", input_edge_visited: " << input_edge_visited + << ", make_tuple_visited: " << make_tuple_visited; + + if (!input_edge_visited && make_tuple_visited) { + cycle_change = EraseMakeTupleInput(func_graph, make_tuple_cnode, i) || cycle_change; + } + } + } + // Check all value tuple + for (const auto &value_tuple : value_tuples) { + auto it = context_manager.contexts_.find(value_tuple); + if (it == context_manager.contexts_.end()) { + continue; + } + cycle_change = EraseValueTuple(value_tuple, it->second->index_stacks_) || cycle_change; + } + change = change || cycle_change; } return change; } diff --git a/mindspore/ccsrc/frontend/optimizer/dead_node_eliminate.h b/mindspore/ccsrc/frontend/optimizer/dead_node_eliminate.h index 09f635fa715..c80471e4887 100644 --- a/mindspore/ccsrc/frontend/optimizer/dead_node_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/dead_node_eliminate.h @@ -25,7 +25,7 @@ class EliminateDeadNodePass { EliminateDeadNodePass() = default; ~EliminateDeadNodePass() = default; bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) { - bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1"; + static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1"; MS_LOG(INFO) << "Closure enable:" << enable_closure; if (!enable_closure) { return false; diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h index 59634210c73..fc7f358e295 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h @@ -304,6 +304,10 @@ class IncorporateEnvGetitem : public AnfVisitor { ~IncorporateEnvGetitem() override = default; AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1"; + if (enable_closure) { + return nullptr; + } is_match_ = false; auto IsGCNode = [](const AnfNodePtr &node) -> bool { auto cnode = node->cast(); @@ -357,6 +361,10 @@ class IncorporateEnvGetitemSwitch : public AnfVisitor { ~IncorporateEnvGetitemSwitch() override = default; AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1"; + if (enable_closure) { + return nullptr; + } is_match_ = false; auto IsSwNode = [](const AnfNodePtr &node) -> bool { auto cnode = node->cast(); @@ -418,6 +426,10 @@ class IncorporateEnvGetitemSwitchLayer : public AnfVisitor { ~IncorporateEnvGetitemSwitchLayer() override = default; AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1"; + if (enable_closure) { + return nullptr; + } is_match_ = false; AnfVisitor::Match(prim::kPrimEnvGetItem, {IsCNode, IsValueNode, IsNode})(node); if (!is_match_ || node->func_graph() == nullptr) { diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h index e1feae82857..758d2d7bf51 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h @@ -1070,6 +1070,10 @@ class IncorporateGetitemSet : public OptimizerCaller { ~IncorporateGetitemSet() = default; AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1"; + if (enable_closure) { + return nullptr; + } AnfNodePtr new_node; for (auto &eliminater : eliminaters_) { new_node = (*eliminater)(optimizer, node);