!21037 optimize CommonSubexpressionElimination pass

Merge pull request !21037 from yuchaojie/vm_pass
This commit is contained in:
i-robot 2021-07-30 01:40:28 +00:00 committed by Gitee
commit 5d9e8da8da
4 changed files with 55 additions and 44 deletions

View File

@ -102,6 +102,11 @@ bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bo
return false;
}
bool BackendCSE::Cse(const FuncGraphPtr graph, const FuncGraphManagerPtr manager) const {
MS_EXCEPTION_IF_NULL(manager);
return BuildOrderGroupAndDoReplaceForOneGraph(graph, manager);
}
bool CommonSubexpressionElimination::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto backend_cse = std::make_shared<BackendCSE>();

View File

@ -35,6 +35,7 @@ class BackendCSE : public CSE {
virtual bool CheckEqualCnodeInputs(const AnfNodePtr &main, const AnfNodePtr &node) const;
bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const override;
virtual bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) const;
bool Cse(const FuncGraphPtr graph, const FuncGraphManagerPtr manager) const override;
};
} // namespace opt
} // namespace mindspore

View File

@ -67,53 +67,55 @@ BasePtr AbsOf(const AnfNodePtr &node, bool ignore_fg_abs_tracking_id) {
return node_abs;
}
bool CSE::BuildOrderGroupAndDoReplaceForOneGraph(const FuncGraphPtr &fg, const FuncGraphManagerPtr &manager) const {
MS_EXCEPTION_IF_NULL(fg);
std::vector<std::size_t> order_group;
std::unordered_map<std::size_t, std::vector<AnfNodePtr>> groups;
std::unordered_map<AnfNodePtr, std::size_t> hashes;
std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
for (auto node : toposet) {
MS_EXCEPTION_IF_NULL(node);
if (hashes.find(node) != hashes.end()) {
continue;
}
std::size_t h = 0;
if (node->isa<ValueNode>()) {
ValueNodePtr value_node = node->cast<ValueNodePtr>();
auto value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
h = hash_combine(value->hash(), (AbsOf(value_node, true)->hash()));
} else if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
auto &inputs = cnode->inputs();
size_t init = 0;
h = std::accumulate(inputs.begin(), inputs.end(), init, [&hashes](std::size_t hash, const AnfNodePtr &node_in) {
return hash_combine(hash, hashes[node_in]);
});
} else if (node->isa<Parameter>()) {
h = node->hash();
} else {
MS_LOG(ERROR) << "Unknown node type";
}
hashes[node] = h;
if (groups.find(h) == groups.end()) {
std::vector<AnfNodePtr> innervec({node});
groups[h] = innervec;
order_group.emplace_back(h);
} else {
groups[h].push_back(node);
}
}
return DoReplace(manager, order_group, &groups);
}
bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
bool changed = false;
for (FuncGraphPtr fg : manager->func_graphs()) {
MS_EXCEPTION_IF_NULL(fg);
std::vector<std::size_t> order_group;
std::unordered_map<std::size_t, std::vector<AnfNodePtr>> groups;
std::unordered_map<AnfNodePtr, std::size_t> hashes;
std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
for (auto node : toposet) {
MS_EXCEPTION_IF_NULL(node);
if (hashes.find(node) != hashes.end()) {
continue;
}
std::size_t h = 0;
if (node->isa<ValueNode>()) {
ValueNodePtr value_node = node->cast<ValueNodePtr>();
auto value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
h = hash_combine(value->hash(), (AbsOf(value_node, true)->hash()));
} else if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
auto &inputs = cnode->inputs();
size_t init = 0;
h = std::accumulate(inputs.begin(), inputs.end(), init, [&hashes](std::size_t hash, const AnfNodePtr &node_in) {
return hash_combine(hash, hashes[node_in]);
});
} else if (node->isa<Parameter>()) {
h = node->hash();
} else {
MS_LOG(ERROR) << "Unknown node type";
}
hashes[node] = h;
if (groups.find(h) == groups.end()) {
std::vector<AnfNodePtr> innervec({node});
groups[h] = innervec;
order_group.emplace_back(h);
} else {
groups[h].push_back(node);
}
}
changed = DoReplace(manager, order_group, &groups) || changed;
changed = BuildOrderGroupAndDoReplaceForOneGraph(fg, manager) || changed;
}
return changed;
}

View File

@ -38,7 +38,10 @@ class CSE {
virtual bool CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const;
bool Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const;
virtual bool Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const;
protected:
bool BuildOrderGroupAndDoReplaceForOneGraph(const FuncGraphPtr &fg, const FuncGraphManagerPtr &manager) const;
private:
bool BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const;