forked from mindspore-Ecosystem/mindspore
!28078 Disable incorporate geitem pass
Merge pull request !28078 from chenfei_mindspore/disbale-inorporate-getitem
This commit is contained in:
commit
c889cc20e2
|
@ -53,6 +53,12 @@ class VisitContext {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool IndexVisited(int64_t index) {
|
||||
return std::any_of(index_stacks_.begin(), index_stacks_.end(), [&index](const std::vector<int64_t> &index_stack) {
|
||||
return !index_stack.empty() && index_stack.back() == index;
|
||||
});
|
||||
}
|
||||
|
||||
std::set<std::vector<int64_t>> index_stacks_;
|
||||
};
|
||||
using VisitContextPtr = std::shared_ptr<VisitContext>;
|
||||
|
@ -66,17 +72,26 @@ class ContextManager {
|
|||
bool AddContext(const AnfNodePtr &node, const std::vector<int64_t> &index_stack) {
|
||||
auto it = contexts_.find(node);
|
||||
if (it == contexts_.end()) {
|
||||
MS_LOG(DEBUG) << "Add node: " << node->DebugString();
|
||||
contexts_[node] = std::make_shared<VisitContext>(index_stack);
|
||||
return true;
|
||||
}
|
||||
return it->second->Add(index_stack);
|
||||
}
|
||||
|
||||
bool IndexVisited(const CNodePtr &node, int64_t index) {
|
||||
auto it = contexts_.find(node);
|
||||
if (it == contexts_.end()) {
|
||||
return false;
|
||||
}
|
||||
return it->second->IndexVisited(index);
|
||||
}
|
||||
};
|
||||
|
||||
void VisitNode(const AnfNodePtr &node, const FuncGraphAnalyzer &analyzer, std::vector<int64_t> index_stack, size_t seen,
|
||||
ContextManager *context_manager) {
|
||||
if (IS_OUTPUT_ON(DEBUG)) {
|
||||
MS_LOG(DEBUG) << "Visit node:" << node->DebugString();
|
||||
MS_LOG(WARNING) << "Visit node:" << node->DebugString();
|
||||
for (size_t i = 0; i < index_stack.size(); i++) {
|
||||
MS_LOG(DEBUG) << "index_stack[" << i << "]: " << index_stack[i];
|
||||
}
|
||||
|
@ -97,7 +112,9 @@ void VisitNode(const AnfNodePtr &node, const FuncGraphAnalyzer &analyzer, std::v
|
|||
index_stack.push_back(output_idx);
|
||||
auto real_input = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem);
|
||||
VisitNode(real_input, analyzer, index_stack, seen, context_manager);
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
|
||||
return;
|
||||
}
|
||||
if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
|
||||
// If make_tuple in make_tuple, visit may start with inner tuple_getitem.
|
||||
if (index_stack.empty()) {
|
||||
return;
|
||||
|
@ -106,13 +123,17 @@ void VisitNode(const AnfNodePtr &node, const FuncGraphAnalyzer &analyzer, std::v
|
|||
auto output_idx = index_stack.back();
|
||||
index_stack.pop_back();
|
||||
VisitNode(make_tuple->input(1 + output_idx), analyzer, index_stack, seen, context_manager);
|
||||
} else if (IsFuncGraphCallNode(node)) {
|
||||
return;
|
||||
}
|
||||
if (IsFuncGraphCallNode(node)) {
|
||||
const auto &caller_func_graphs = analyzer.GetCallerFuncGraphs(node);
|
||||
for (const auto &fg : caller_func_graphs) {
|
||||
auto new_index_stack = std::vector<int64_t>(index_stack);
|
||||
VisitNode(fg->output(), analyzer, new_index_stack, seen, context_manager);
|
||||
}
|
||||
} else if (node->isa<Parameter>()) {
|
||||
return;
|
||||
}
|
||||
if (node->isa<Parameter>()) {
|
||||
const auto &func_callers = analyzer.GetFuncGraphCallers(node->func_graph());
|
||||
for (auto &caller : func_callers) {
|
||||
const auto &args = analyzer.GetArg(node, caller);
|
||||
|
@ -121,25 +142,13 @@ void VisitNode(const AnfNodePtr &node, const FuncGraphAnalyzer &analyzer, std::v
|
|||
VisitNode(arg, analyzer, new_index_stack, seen, context_manager);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (!index_stack.empty()) {
|
||||
// TupleGetItem's input may not be a MakeTuple but a ValueTuple.
|
||||
MS_LOG(DEBUG) << "Reach the end node: " << node->DebugString() << ", but index stack is not empty.";
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
void EraseMakeTupleInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||
// Don't eliminate the parameter of graph
|
||||
if (node->isa<Parameter>()) {
|
||||
MS_LOG(WARNING) << "Parameter:" << node->DebugString() << " is dead node and can't be erased.";
|
||||
if (node->isa<ValueTuple>()) {
|
||||
// TupleGetItem's input may not be a MakeTuple but a ValueTuple.
|
||||
return;
|
||||
}
|
||||
auto new_tensor = NewValueNode(MakeValue(0));
|
||||
auto abs = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(0));
|
||||
new_tensor->set_abstract(abs);
|
||||
func_graph->manager()->Replace(node, new_tensor);
|
||||
MS_LOG(DEBUG) << "Reach the end node: " << node->DebugString() << ", but index stack is not empty.";
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> GenerateOutputTempGetItems(const FuncGraphPtr &func_graph) {
|
||||
|
@ -180,17 +189,111 @@ bool IsScalarValueNode(const AnfNodePtr &node) {
|
|||
return node->abstract()->isa<abstract::AbstractScalar>();
|
||||
}
|
||||
|
||||
bool EliminateDeadNode(const FuncGraphPtr &func_graph) {
|
||||
std::vector<AnfNodePtr> tuple_getitem_nodes;
|
||||
std::vector<AnfNodePtr> make_tuple_nodes;
|
||||
const auto &all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude);
|
||||
for (const auto &node : all_nodes) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
|
||||
tuple_getitem_nodes.emplace_back(node);
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
|
||||
make_tuple_nodes.emplace_back(node);
|
||||
bool EraseMakeTupleInput(const FuncGraphPtr &func_graph, const CNodePtr &make_tuple, size_t input_idx) {
|
||||
// Scalar(int) no need convert to Scalar(0), and Scalar(0) cannot be erased once again.
|
||||
auto node = make_tuple->input(input_idx);
|
||||
if (IsScalarValueNode(node)) {
|
||||
return false;
|
||||
}
|
||||
MS_LOG(WARNING) << "Erase dead node: " << node->DebugString() << ", user make_tuple: " << make_tuple->DebugString();
|
||||
auto new_tensor = NewValueNode(MakeValue(0));
|
||||
auto abs = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(0));
|
||||
new_tensor->set_abstract(abs);
|
||||
// Can't use `Replace`, must user `SetEdge`.
|
||||
func_graph->manager()->SetEdge(make_tuple, input_idx, new_tensor);
|
||||
return true;
|
||||
}
|
||||
|
||||
void VisitValue(const ValuePtr &value, std::vector<int64_t> indexes,
|
||||
HashMap<ValuePtr, HashSet<int64_t>> *visited_values) {
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
MS_LOG(DEBUG) << "Visit value:" << value->ToString();
|
||||
if (indexes.empty()) {
|
||||
MS_LOG(DEBUG) << "Indexes empty";
|
||||
return;
|
||||
}
|
||||
const auto visit_index = indexes.back();
|
||||
(*visited_values)[value].insert(visit_index);
|
||||
auto value_tuple = value->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_tuple);
|
||||
if (LongToSize(visit_index) >= value_tuple->size()) {
|
||||
MS_LOG(EXCEPTION) << "Index: " << visit_index << " out of range: " << value_tuple->size();
|
||||
}
|
||||
indexes.pop_back();
|
||||
MS_LOG(DEBUG) << "Visit index: " << visit_index;
|
||||
VisitValue(value_tuple->value()[LongToSize(visit_index)], indexes, visited_values);
|
||||
}
|
||||
|
||||
std::pair<ValuePtr, abstract::AbstractBasePtr> EraseValue(const ValuePtr &value, const abstract::AbstractBasePtr &abs,
|
||||
const HashMap<ValuePtr, HashSet<int64_t>> &visited_values,
|
||||
bool need_erase) {
|
||||
if (need_erase) {
|
||||
auto new_value = MakeValue(0);
|
||||
auto new_abs = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(0));
|
||||
new_abs->set_value(new_value);
|
||||
MS_LOG(WARNING) << "Erase value:" << value->ToString();
|
||||
return {new_value, new_abs};
|
||||
}
|
||||
auto it = visited_values.find(value);
|
||||
if (it == visited_values.end()) {
|
||||
return {value, abs};
|
||||
}
|
||||
const auto &all_visit_index = it->second;
|
||||
|
||||
auto value_tuple = value->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_tuple);
|
||||
auto abs_tuple = abs->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(abs_tuple);
|
||||
auto new_elements = std::vector<ValuePtr>(value_tuple->value());
|
||||
auto new_abstracts = std::vector<abstract::AbstractBasePtr>(abs_tuple->elements());
|
||||
if (new_elements.size() != new_abstracts.size()) {
|
||||
MS_LOG(EXCEPTION) << "Value size: " << new_elements.size()
|
||||
<< " is not equal to abstract size: " << new_abstracts.size();
|
||||
}
|
||||
|
||||
bool change = false;
|
||||
for (size_t i = 0; i < value_tuple->value().size(); i++) {
|
||||
auto value_i = new_elements[i];
|
||||
auto abs_i = new_abstracts[i];
|
||||
// Avoid repeatedly erase.
|
||||
MS_LOG(WARNING) << "value_i:[" << i << "]: " << value_i->ToString();
|
||||
if (value_i->isa<Scalar>()) {
|
||||
continue;
|
||||
}
|
||||
bool need_erase_i = all_visit_index.find(SizeToLong(i)) == all_visit_index.end();
|
||||
auto [ret_value, ret_abs] = EraseValue(value_i, abs_i, visited_values, need_erase_i);
|
||||
if (ret_value != value_i) {
|
||||
new_elements[i] = ret_value;
|
||||
new_abstracts[i] = ret_abs;
|
||||
change = true;
|
||||
}
|
||||
}
|
||||
if (change) {
|
||||
value_tuple = std::make_shared<ValueTuple>(new_elements);
|
||||
abs_tuple = std::make_shared<abstract::AbstractTuple>(new_abstracts);
|
||||
abs_tuple->set_value(value_tuple);
|
||||
}
|
||||
return {value_tuple, abs_tuple};
|
||||
}
|
||||
|
||||
bool EraseValueTuple(const AnfNodePtr &node, const std::set<std::vector<int64_t>> &contexts) {
|
||||
HashMap<ValuePtr, HashSet<int64_t>> visited_values;
|
||||
const auto value = GetValueNode(node);
|
||||
for (const auto &context : contexts) {
|
||||
VisitValue(value, context, &visited_values);
|
||||
}
|
||||
// Erase the unvisited values.
|
||||
auto [new_value, new_abs] = EraseValue(value, node->abstract(), visited_values, false);
|
||||
if (new_value != value) {
|
||||
node->cast<ValueNodePtr>()->set_value(new_value);
|
||||
node->set_abstract(new_abs);
|
||||
MS_LOG(DEBUG) << "Set new value of node: " << node->DebugString();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool EliminateDeadNode(const FuncGraphPtr &func_graph) {
|
||||
// Travers all tuple getitem nodes to visit.
|
||||
FuncGraphAnalyzer analyzer(func_graph);
|
||||
analyzer.Run();
|
||||
|
@ -198,31 +301,63 @@ bool EliminateDeadNode(const FuncGraphPtr &func_graph) {
|
|||
if (!analyzer.HasIncorporateCall()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto seen = NewSeenGeneration();
|
||||
std::vector<int64_t> index_stack;
|
||||
ContextManager context_manager;
|
||||
// Visit from all tuple_getitem.
|
||||
for (const auto &tuple_getitem : tuple_getitem_nodes) {
|
||||
VisitNode(tuple_getitem, analyzer, index_stack, seen, &context_manager);
|
||||
}
|
||||
// Visit from root graph output.
|
||||
const auto &output_getitems = GenerateOutputTempGetItems(func_graph);
|
||||
for (const auto &tuple_getitem : output_getitems) {
|
||||
VisitNode(tuple_getitem, analyzer, index_stack, seen, &context_manager);
|
||||
}
|
||||
// Check all make tuple's input
|
||||
bool change = false;
|
||||
for (const auto &make_tuple : make_tuple_nodes) {
|
||||
auto make_tuple_cnode = make_tuple->cast<CNodePtr>();
|
||||
for (size_t i = 1; i < make_tuple_cnode->size(); i++) {
|
||||
const auto &input = make_tuple_cnode->input(i);
|
||||
// If make_tuple was not visited ,it may be a make tuple of swith_layer or addn and some other ops.
|
||||
if (input->seen_ != seen && make_tuple_cnode->seen_ == seen && !IsScalarValueNode(input)) {
|
||||
MS_LOG(INFO) << "Find dead node: " << input->DebugString();
|
||||
change = true;
|
||||
EraseMakeTupleInput(func_graph, input);
|
||||
bool cycle_change = true;
|
||||
while (cycle_change) {
|
||||
ContextManager context_manager;
|
||||
std::vector<AnfNodePtr> tuple_getitem_nodes;
|
||||
std::vector<AnfNodePtr> make_tuple_nodes;
|
||||
std::vector<AnfNodePtr> value_tuples;
|
||||
const auto &all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude);
|
||||
for (const auto &node : all_nodes) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
|
||||
tuple_getitem_nodes.emplace_back(node);
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
|
||||
make_tuple_nodes.emplace_back(node);
|
||||
} else if (IsValueNode<ValueTuple>(node)) {
|
||||
value_tuples.emplace_back(node);
|
||||
}
|
||||
}
|
||||
// Visit from all tuple_getitem.
|
||||
for (const auto &tuple_getitem : tuple_getitem_nodes) {
|
||||
VisitNode(tuple_getitem, analyzer, index_stack, seen, &context_manager);
|
||||
}
|
||||
// Visit from root graph output.
|
||||
const auto &output_getitems = GenerateOutputTempGetItems(func_graph);
|
||||
for (const auto &tuple_getitem : output_getitems) {
|
||||
VisitNode(tuple_getitem, analyzer, index_stack, seen, &context_manager);
|
||||
}
|
||||
// Check all make tuple's input
|
||||
cycle_change = false;
|
||||
for (const auto &make_tuple : make_tuple_nodes) {
|
||||
MS_LOG(WARNING) << "Check make_tuple:" << make_tuple->DebugString();
|
||||
auto make_tuple_cnode = make_tuple->cast<CNodePtr>();
|
||||
for (size_t i = 1; i < make_tuple_cnode->size(); i++) {
|
||||
// If make_tuple was not visited ,it may be a make tuple of swith_layer or addn and some other ops.
|
||||
auto input_edge_visited = context_manager.IndexVisited(make_tuple_cnode, i - 1);
|
||||
// Can use `context_manager.contexts_.find(make_tuple_cnode) != context_manager.contexts_.end()`.
|
||||
auto make_tuple_visited = make_tuple_cnode->seen_ == seen;
|
||||
MS_LOG(WARNING) << "Check [" << i - 1 << "]:"
|
||||
<< ", input_edge_visited: " << input_edge_visited
|
||||
<< ", make_tuple_visited: " << make_tuple_visited;
|
||||
|
||||
if (!input_edge_visited && make_tuple_visited) {
|
||||
cycle_change = EraseMakeTupleInput(func_graph, make_tuple_cnode, i) || cycle_change;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Check all value tuple
|
||||
for (const auto &value_tuple : value_tuples) {
|
||||
auto it = context_manager.contexts_.find(value_tuple);
|
||||
if (it == context_manager.contexts_.end()) {
|
||||
continue;
|
||||
}
|
||||
cycle_change = EraseValueTuple(value_tuple, it->second->index_stacks_) || cycle_change;
|
||||
}
|
||||
change = change || cycle_change;
|
||||
}
|
||||
return change;
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@ class EliminateDeadNodePass {
|
|||
EliminateDeadNodePass() = default;
|
||||
~EliminateDeadNodePass() = default;
|
||||
bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
|
||||
bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1";
|
||||
static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1";
|
||||
MS_LOG(INFO) << "Closure enable:" << enable_closure;
|
||||
if (!enable_closure) {
|
||||
return false;
|
||||
|
|
|
@ -304,6 +304,10 @@ class IncorporateEnvGetitem : public AnfVisitor {
|
|||
~IncorporateEnvGetitem() override = default;
|
||||
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1";
|
||||
if (enable_closure) {
|
||||
return nullptr;
|
||||
}
|
||||
is_match_ = false;
|
||||
auto IsGCNode = [](const AnfNodePtr &node) -> bool {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
|
@ -357,6 +361,10 @@ class IncorporateEnvGetitemSwitch : public AnfVisitor {
|
|||
~IncorporateEnvGetitemSwitch() override = default;
|
||||
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1";
|
||||
if (enable_closure) {
|
||||
return nullptr;
|
||||
}
|
||||
is_match_ = false;
|
||||
auto IsSwNode = [](const AnfNodePtr &node) -> bool {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
|
@ -418,6 +426,10 @@ class IncorporateEnvGetitemSwitchLayer : public AnfVisitor {
|
|||
~IncorporateEnvGetitemSwitchLayer() override = default;
|
||||
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1";
|
||||
if (enable_closure) {
|
||||
return nullptr;
|
||||
}
|
||||
is_match_ = false;
|
||||
AnfVisitor::Match(prim::kPrimEnvGetItem, {IsCNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
|
||||
if (!is_match_ || node->func_graph() == nullptr) {
|
||||
|
|
|
@ -1070,6 +1070,10 @@ class IncorporateGetitemSet : public OptimizerCaller {
|
|||
~IncorporateGetitemSet() = default;
|
||||
|
||||
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||
static bool enable_closure = common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "1";
|
||||
if (enable_closure) {
|
||||
return nullptr;
|
||||
}
|
||||
AnfNodePtr new_node;
|
||||
for (auto &eliminater : eliminaters_) {
|
||||
new_node = (*eliminater)(optimizer, node);
|
||||
|
|
Loading…
Reference in New Issue