!21037 optimize CommonSubexpressionElimination pass
Merge pull request !21037 from yuchaojie/vm_pass
This commit is contained in:
commit
5d9e8da8da
|
@ -102,6 +102,11 @@ bool BackendCSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bo
|
|||
return false;
|
||||
}
|
||||
|
||||
bool BackendCSE::Cse(const FuncGraphPtr graph, const FuncGraphManagerPtr manager) const {
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
return BuildOrderGroupAndDoReplaceForOneGraph(graph, manager);
|
||||
}
|
||||
|
||||
bool CommonSubexpressionElimination::Run(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto backend_cse = std::make_shared<BackendCSE>();
|
||||
|
|
|
@ -35,6 +35,7 @@ class BackendCSE : public CSE {
|
|||
virtual bool CheckEqualCnodeInputs(const AnfNodePtr &main, const AnfNodePtr &node) const;
|
||||
bool CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node, bool check_side_effect = true) const override;
|
||||
virtual bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) const;
|
||||
bool Cse(const FuncGraphPtr graph, const FuncGraphManagerPtr manager) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -67,53 +67,55 @@ BasePtr AbsOf(const AnfNodePtr &node, bool ignore_fg_abs_tracking_id) {
|
|||
return node_abs;
|
||||
}
|
||||
|
||||
bool CSE::BuildOrderGroupAndDoReplaceForOneGraph(const FuncGraphPtr &fg, const FuncGraphManagerPtr &manager) const {
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
std::vector<std::size_t> order_group;
|
||||
std::unordered_map<std::size_t, std::vector<AnfNodePtr>> groups;
|
||||
std::unordered_map<AnfNodePtr, std::size_t> hashes;
|
||||
|
||||
std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
|
||||
for (auto node : toposet) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (hashes.find(node) != hashes.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::size_t h = 0;
|
||||
if (node->isa<ValueNode>()) {
|
||||
ValueNodePtr value_node = node->cast<ValueNodePtr>();
|
||||
auto value = value_node->value();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
h = hash_combine(value->hash(), (AbsOf(value_node, true)->hash()));
|
||||
} else if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto &inputs = cnode->inputs();
|
||||
size_t init = 0;
|
||||
h = std::accumulate(inputs.begin(), inputs.end(), init, [&hashes](std::size_t hash, const AnfNodePtr &node_in) {
|
||||
return hash_combine(hash, hashes[node_in]);
|
||||
});
|
||||
} else if (node->isa<Parameter>()) {
|
||||
h = node->hash();
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unknown node type";
|
||||
}
|
||||
|
||||
hashes[node] = h;
|
||||
if (groups.find(h) == groups.end()) {
|
||||
std::vector<AnfNodePtr> innervec({node});
|
||||
groups[h] = innervec;
|
||||
order_group.emplace_back(h);
|
||||
} else {
|
||||
groups[h].push_back(node);
|
||||
}
|
||||
}
|
||||
return DoReplace(manager, order_group, &groups);
|
||||
}
|
||||
|
||||
bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const {
|
||||
bool changed = false;
|
||||
for (FuncGraphPtr fg : manager->func_graphs()) {
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
std::vector<std::size_t> order_group;
|
||||
std::unordered_map<std::size_t, std::vector<AnfNodePtr>> groups;
|
||||
std::unordered_map<AnfNodePtr, std::size_t> hashes;
|
||||
|
||||
std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
|
||||
for (auto node : toposet) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (hashes.find(node) != hashes.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::size_t h = 0;
|
||||
if (node->isa<ValueNode>()) {
|
||||
ValueNodePtr value_node = node->cast<ValueNodePtr>();
|
||||
auto value = value_node->value();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
h = hash_combine(value->hash(), (AbsOf(value_node, true)->hash()));
|
||||
} else if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto &inputs = cnode->inputs();
|
||||
size_t init = 0;
|
||||
h = std::accumulate(inputs.begin(), inputs.end(), init, [&hashes](std::size_t hash, const AnfNodePtr &node_in) {
|
||||
return hash_combine(hash, hashes[node_in]);
|
||||
});
|
||||
} else if (node->isa<Parameter>()) {
|
||||
h = node->hash();
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unknown node type";
|
||||
}
|
||||
|
||||
hashes[node] = h;
|
||||
if (groups.find(h) == groups.end()) {
|
||||
std::vector<AnfNodePtr> innervec({node});
|
||||
groups[h] = innervec;
|
||||
order_group.emplace_back(h);
|
||||
} else {
|
||||
groups[h].push_back(node);
|
||||
}
|
||||
}
|
||||
|
||||
changed = DoReplace(manager, order_group, &groups) || changed;
|
||||
changed = BuildOrderGroupAndDoReplaceForOneGraph(fg, manager) || changed;
|
||||
}
|
||||
|
||||
return changed;
|
||||
}
|
||||
|
||||
|
|
|
@ -38,7 +38,10 @@ class CSE {
|
|||
|
||||
virtual bool CheckRandomEffect(const AnfNodePtr &main, const AnfNodePtr &node) const;
|
||||
|
||||
bool Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const;
|
||||
virtual bool Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) const;
|
||||
|
||||
protected:
|
||||
bool BuildOrderGroupAndDoReplaceForOneGraph(const FuncGraphPtr &fg, const FuncGraphManagerPtr &manager) const;
|
||||
|
||||
private:
|
||||
bool BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const;
|
||||
|
|
Loading…
Reference in New Issue