!28078 Disable incorporate geitem pass

Merge pull request !28078 from chenfei_mindspore/disbale-inorporate-getitem
This commit is contained in:
i-robot 2021-12-27 02:57:46 +00:00 committed by Gitee
commit c889cc20e2
4 changed files with 200 additions and 49 deletions

View File

@ -53,6 +53,12 @@ class VisitContext {
return true;
}
bool IndexVisited(int64_t index) {
return std::any_of(index_stacks_.begin(), index_stacks_.end(), [&index](const std::vector<int64_t> &index_stack) {
return !index_stack.empty() && index_stack.back() == index;
});
}
std::set<std::vector<int64_t>> index_stacks_;
};
using VisitContextPtr = std::shared_ptr<VisitContext>;
@ -66,17 +72,26 @@ class ContextManager {
bool AddContext(const AnfNodePtr &node, const std::vector<int64_t> &index_stack) {
auto it = contexts_.find(node);
if (it == contexts_.end()) {
MS_LOG(DEBUG) << "Add node: " << node->DebugString();
contexts_[node] = std::make_shared<VisitContext>(index_stack);
return true;
}
return it->second->Add(index_stack);
}
bool IndexVisited(const CNodePtr &node, int64_t index) {
auto it = contexts_.find(node);
if (it == contexts_.end()) {
return false;
}
return it->second->IndexVisited(index);
}
};
void VisitNode(const AnfNodePtr &node, const FuncGraphAnalyzer &analyzer, std::vector<int64_t> index_stack, size_t seen,
ContextManager *context_manager) {
if (IS_OUTPUT_ON(DEBUG)) {
MS_LOG(DEBUG) << "Visit node:" << node->DebugString();
MS_LOG(WARNING) << "Visit node:" << node->DebugString();
for (size_t i = 0; i < index_stack.size(); i++) {
MS_LOG(DEBUG) << "index_stack[" << i << "]: " << index_stack[i];
}
@ -97,7 +112,9 @@ void VisitNode(const AnfNodePtr &node, const FuncGraphAnalyzer &analyzer, std::v
index_stack.push_back(output_idx);
auto real_input = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem);
VisitNode(real_input, analyzer, index_stack, seen, context_manager);
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
return;
}
if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
// If make_tuple in make_tuple, visit may start with inner tuple_getitem.
if (index_stack.empty()) {
return;
@ -106,13 +123,17 @@ void VisitNode(const AnfNodePtr &node, const FuncGraphAnalyzer &analyzer, std::v
auto output_idx = index_stack.back();
index_stack.pop_back();
VisitNode(make_tuple->input(1 + output_idx), analyzer, index_stack, seen, context_manager);
} else if (IsFuncGraphCallNode(node)) {
return;
}
if (IsFuncGraphCallNode(node)) {
const auto &caller_func_graphs = analyzer.GetCallerFuncGraphs(node);
for (const auto &fg : caller_func_graphs) {
auto new_index_stack = std::vector<int64_t>(index_stack);
VisitNode(fg->output(), analyzer, new_index_stack, seen, context_manager);
}
} else if (node->isa<Parameter>()) {
return;
}
if (node->isa<Parameter>()) {
const auto &func_callers = analyzer.GetFuncGraphCallers(node->func_graph());
for (auto &caller : func_callers) {
const auto &args = analyzer.GetArg(node, caller);
@ -121,25 +142,13 @@ void VisitNode(const AnfNodePtr &node, const FuncGraphAnalyzer &analyzer, std::v
VisitNode(arg, analyzer, new_index_stack, seen, context_manager);
}
}
} else {
if (!index_stack.empty()) {
// TupleGetItem's input may not be a MakeTuple but a ValueTuple.
MS_LOG(DEBUG) << "Reach the end node: " << node->DebugString() << ", but index stack is not empty.";
}
return;
}
}
void EraseMakeTupleInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
// Don't eliminate the parameter of graph
if (node->isa<Parameter>()) {
MS_LOG(WARNING) << "Parameter:" << node->DebugString() << " is dead node and can't be erased.";
if (node->isa<ValueTuple>()) {
// TupleGetItem's input may not be a MakeTuple but a ValueTuple.
return;
}
auto new_tensor = NewValueNode(MakeValue(0));
auto abs = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(0));
new_tensor->set_abstract(abs);
func_graph->manager()->Replace(node, new_tensor);
MS_LOG(DEBUG) << "Reach the end node: " << node->DebugString() << ", but index stack is not empty.";
}
std::vector<AnfNodePtr> GenerateOutputTempGetItems(const FuncGraphPtr &func_graph) {
@ -180,17 +189,111 @@ bool IsScalarValueNode(const AnfNodePtr &node) {
return node->abstract()->isa<abstract::AbstractScalar>();
}
bool EliminateDeadNode(const FuncGraphPtr &func_graph) {
std::vector<AnfNodePtr> tuple_getitem_nodes;
std::vector<AnfNodePtr> make_tuple_nodes;
const auto &all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude);
for (const auto &node : all_nodes) {
if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
tuple_getitem_nodes.emplace_back(node);
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
make_tuple_nodes.emplace_back(node);
bool EraseMakeTupleInput(const FuncGraphPtr &func_graph, const CNodePtr &make_tuple, size_t input_idx) {
// Scalar(int) no need convert to Scalar(0), and Scalar(0) cannot be erased once again.
auto node = make_tuple->input(input_idx);
if (IsScalarValueNode(node)) {
return false;
}
MS_LOG(WARNING) << "Erase dead node: " << node->DebugString() << ", user make_tuple: " << make_tuple->DebugString();
auto new_tensor = NewValueNode(MakeValue(0));
auto abs = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(0));
new_tensor->set_abstract(abs);
// Can't use `Replace`, must user `SetEdge`.
func_graph->manager()->SetEdge(make_tuple, input_idx, new_tensor);
return true;
}
void VisitValue(const ValuePtr &value, std::vector<int64_t> indexes,
HashMap<ValuePtr, HashSet<int64_t>> *visited_values) {
MS_EXCEPTION_IF_NULL(value);
MS_LOG(DEBUG) << "Visit value:" << value->ToString();
if (indexes.empty()) {
MS_LOG(DEBUG) << "Indexes empty";
return;
}
const auto visit_index = indexes.back();
(*visited_values)[value].insert(visit_index);
auto value_tuple = value->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple);
if (LongToSize(visit_index) >= value_tuple->size()) {
MS_LOG(EXCEPTION) << "Index: " << visit_index << " out of range: " << value_tuple->size();
}
indexes.pop_back();
MS_LOG(DEBUG) << "Visit index: " << visit_index;
VisitValue(value_tuple->value()[LongToSize(visit_index)], indexes, visited_values);
}
std::pair<ValuePtr, abstract::AbstractBasePtr> EraseValue(const ValuePtr &value, const abstract::AbstractBasePtr &abs,
const HashMap<ValuePtr, HashSet<int64_t>> &visited_values,
bool need_erase) {
if (need_erase) {
auto new_value = MakeValue(0);
auto new_abs = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(0));
new_abs->set_value(new_value);
MS_LOG(WARNING) << "Erase value:" << value->ToString();
return {new_value, new_abs};
}
auto it = visited_values.find(value);
if (it == visited_values.end()) {
return {value, abs};
}
const auto &all_visit_index = it->second;
auto value_tuple = value->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple);
auto abs_tuple = abs->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(abs_tuple);
auto new_elements = std::vector<ValuePtr>(value_tuple->value());
auto new_abstracts = std::vector<abstract::AbstractBasePtr>(abs_tuple->elements());
if (new_elements.size() != new_abstracts.size()) {
MS_LOG(EXCEPTION) << "Value size: " << new_elements.size()
<< " is not equal to abstract size: " << new_abstracts.size();
}
bool change = false;
for (size_t i = 0; i < value_tuple->value().size(); i++) {
auto value_i = new_elements[i];
auto abs_i = new_abstracts[i];
// Avoid repeatedly erase.
MS_LOG(WARNING) << "value_i:[" << i << "]: " << value_i->ToString();
if (value_i->isa<Scalar>()) {
continue;
}
bool need_erase_i = all_visit_index.find(SizeToLong(i)) == all_visit_index.end();
auto [ret_value, ret_abs] = EraseValue(value_i, abs_i, visited_values, need_erase_i);
if (ret_value != value_i) {
new_elements[i] = ret_value;
new_abstracts[i] = ret_abs;
change = true;
}
}
if (change) {
value_tuple = std::make_shared<ValueTuple>(new_elements);
abs_tuple = std::make_shared<abstract::AbstractTuple>(new_abstracts);
abs_tuple->set_value(value_tuple);
}
return {value_tuple, abs_tuple};
}
bool EraseValueTuple(const AnfNodePtr &node, const std::set<std::vector<int64_t>> &contexts) {
HashMap<ValuePtr, HashSet<int64_t>> visited_values;
const auto value = GetValueNode(node);
for (const auto &context : contexts) {
VisitValue(value, context, &visited_values);
}
// Erase the unvisited values.
auto [new_value, new_abs] = EraseValue(value, node->abstract(), visited_values, false);
if (new_value != value) {
node->cast<ValueNodePtr>()->set_value(new_value);
node->set_abstract(new_abs);
MS_LOG(DEBUG) << "Set new value of node: " << node->DebugString();
return true;
}
return false;
}
bool EliminateDeadNode(const FuncGraphPtr &func_graph) {
// Travers all tuple getitem nodes to visit.
FuncGraphAnalyzer analyzer(func_graph);
analyzer.Run();
@ -198,31 +301,63 @@ bool EliminateDeadNode(const FuncGraphPtr &func_graph) {
if (!analyzer.HasIncorporateCall()) {
return false;
}
auto seen = NewSeenGeneration();
std::vector<int64_t> index_stack;
ContextManager context_manager;
// Visit from all tuple_getitem.
for (const auto &tuple_getitem : tuple_getitem_nodes) {
VisitNode(tuple_getitem, analyzer, index_stack, seen, &context_manager);
}
// Visit from root graph output.
const auto &output_getitems = GenerateOutputTempGetItems(func_graph);
for (const auto &tuple_getitem : output_getitems) {
VisitNode(tuple_getitem, analyzer, index_stack, seen, &context_manager);
}
// Check all make tuple's input
bool change = false;
for (const auto &make_tuple : make_tuple_nodes) {
auto make_tuple_cnode = make_tuple->cast<CNodePtr>();
for (size_t i = 1; i < make_tuple_cnode->size(); i++) {
const auto &input = make_tuple_cnode->input(i);
// If make_tuple was not visited ,it may be a make tuple of swith_layer or addn and some other ops.
if (input->seen_ != seen && make_tuple_cnode->seen_ == seen && !IsScalarValueNode(input)) {
MS_LOG(INFO) << "Find dead node: " << input->DebugString();
change = true;
EraseMakeTupleInput(func_graph, input);
bool cycle_change = true;
while (cycle_change) {
ContextManager context_manager;
std::vector<AnfNodePtr> tuple_getitem_nodes;
std::vector<AnfNodePtr> make_tuple_nodes;
std::vector<AnfNodePtr> value_tuples;
const auto &all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude);
for (const auto &node : all_nodes) {
if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
tuple_getitem_nodes.emplace_back(node);
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
make_tuple_nodes.emplace_back(node);
} else if (IsValueNode<ValueTuple>(node)) {
value_tuples.emplace_back(node);
}
}
// Visit from all tuple_getitem.
for (const auto &tuple_getitem : tuple_getitem_nodes) {
VisitNode(tuple_getitem, analyzer, index_stack, seen, &context_manager);
}
// Visit from root graph output.
const auto &output_getitems = GenerateOutputTempGetItems(func_graph);
for (const auto &tuple_getitem : output_getitems) {
VisitNode(tuple_getitem, analyzer, index_stack, seen, &context_manager);
}
// Check all make tuple's input
cycle_change = false;
for (const auto &make_tuple : make_tuple_nodes) {
MS_LOG(WARNING) << "Check make_tuple:" << make_tuple->DebugString();
auto make_tuple_cnode = make_tuple->cast<CNodePtr>();
for (size_t i = 1; i < make_tuple_cnode->size(); i++) {
// If make_tuple was not visited ,it may be a make tuple of swith_layer or addn and some other ops.
auto input_edge_visited = context_manager.IndexVisited(make_tuple_cnode, i - 1);
// Can use `context_manager.contexts_.find(make_tuple_cnode) != context_manager.contexts_.end()`.
auto make_tuple_visited = make_tuple_cnode->seen_ == seen;
MS_LOG(WARNING) << "Check [" << i - 1 << "]:"
<< ", input_edge_visited: " << input_edge_visited
<< ", make_tuple_visited: " << make_tuple_visited;
if (!input_edge_visited && make_tuple_visited) {
cycle_change = EraseMakeTupleInput(func_graph, make_tuple_cnode, i) || cycle_change;
}
}
}
// Check all value tuple
for (const auto &value_tuple : value_tuples) {
auto it = context_manager.contexts_.find(value_tuple);
if (it == context_manager.contexts_.end()) {
continue;
}
cycle_change = EraseValueTuple(value_tuple, it->second->index_stacks_) || cycle_change;
}
change = change || cycle_change;
}
return change;
}

View File

@ -25,7 +25,7 @@ class EliminateDeadNodePass {
EliminateDeadNodePass() = default;
~EliminateDeadNodePass() = default;
bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1";
static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1";
MS_LOG(INFO) << "Closure enable:" << enable_closure;
if (!enable_closure) {
return false;

View File

@ -304,6 +304,10 @@ class IncorporateEnvGetitem : public AnfVisitor {
~IncorporateEnvGetitem() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1";
if (enable_closure) {
return nullptr;
}
is_match_ = false;
auto IsGCNode = [](const AnfNodePtr &node) -> bool {
auto cnode = node->cast<CNodePtr>();
@ -357,6 +361,10 @@ class IncorporateEnvGetitemSwitch : public AnfVisitor {
~IncorporateEnvGetitemSwitch() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1";
if (enable_closure) {
return nullptr;
}
is_match_ = false;
auto IsSwNode = [](const AnfNodePtr &node) -> bool {
auto cnode = node->cast<CNodePtr>();
@ -418,6 +426,10 @@ class IncorporateEnvGetitemSwitchLayer : public AnfVisitor {
~IncorporateEnvGetitemSwitchLayer() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1";
if (enable_closure) {
return nullptr;
}
is_match_ = false;
AnfVisitor::Match(prim::kPrimEnvGetItem, {IsCNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
if (!is_match_ || node->func_graph() == nullptr) {

View File

@ -1070,6 +1070,10 @@ class IncorporateGetitemSet : public OptimizerCaller {
~IncorporateGetitemSet() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1";
if (enable_closure) {
return nullptr;
}
AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) {
new_node = (*eliminater)(optimizer, node);