Fuse composite ops separated by GetItem nodes

This commit is contained in:
dayschan 2020-11-13 11:55:11 +08:00
parent e2e532dec3
commit 8e6d92eac9
1 changed files with 81 additions and 11 deletions

View File

@ -36,6 +36,7 @@
#include "debug/anf_ir_dump.h"
#include "ir/func_graph_cloner.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/pass/getitem_tuple.h"
namespace mindspore {
namespace opt {
@ -60,20 +61,20 @@ std::vector<AnfNodePtr> DeepUsersSearch(const std::vector<AnfNodePtr> &roots, co
}
} // namespace
bool IsFuse(const AnfNodePtr &node) {
// composite fuse composite op
if (AnfAlgo::IsGraphKernel(node)) {
return true;
}
return IsBasicFuseOp(node);
}
IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNodePtr &node) {
if (cur_node == node) {
return FOLLOW;
}
bool is_fusable = IsFuse(node);
return is_fusable ? FOLLOW : EXCLUDE;
if (AnfAlgo::IsGraphKernel(node) || IsBasicFuseOp(node)) {
return FOLLOW;
}
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
auto prev_node = node->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
if (AnfAlgo::IsGraphKernel(prev_node)) {
return FOLLOW;
}
}
return EXCLUDE;
}
IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, const AnfNodePtr &node) {
@ -185,6 +186,60 @@ std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bo
return res;
}
// The GetItem node should be fused with its real input and users.
// If its real input is not in the fuse_list, the GetItem should be excluded.
AnfNodePtrList RemoveWildGetitem(const AnfNodePtrList &fused_op) {
if (fused_op.empty()) return AnfNodePtrList();
std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end());
auto check_include = [&fused_op_set](const AnfNodePtr &node) { return fused_op_set.count(node) ? FOLLOW : EXCLUDE; };
auto mng = fused_op[0]->func_graph()->manager();
MS_EXCEPTION_IF_NULL(mng);
bool changed = true;
while (changed) {
changed = false;
AnfNodePtrList remove_list;
for (auto node : fused_op_set) {
if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) continue;
// GetItem should be fused with its real input.
auto prev_node = node->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
if (check_include(prev_node) == EXCLUDE) {
remove_list.push_back(node);
break;
}
// GetItem should be fused with its all users.
auto &users = mng->node_users()[node];
bool outside_user_found = false;
for (auto iter = users.begin(); iter != users.end(); ++iter) {
if (check_include(iter->first) == EXCLUDE) {
outside_user_found = true;
break;
}
}
if (outside_user_found) {
remove_list = DeepUsersSearch(node, check_include, mng);
break;
}
}
if (!remove_list.empty()) {
for (auto node : remove_list) {
fused_op_set.erase(node);
}
changed = true;
}
}
// keep the original order of fused_op.
AnfNodePtrList result;
for (auto node : fused_op) {
if (fused_op_set.count(node)) {
result.push_back(node);
}
}
return result;
}
void TopoSortForNodeList(std::vector<AnfNodePtr> *lst) {
if (lst->size() < 2) {
return;
@ -254,6 +309,7 @@ std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode) {
if (used_nodes.size() > 1) {
used_nodes = RemoveCircle(used_nodes);
}
used_nodes = RemoveWildGetitem(used_nodes);
TopoSortForNodeList(&used_nodes);
return used_nodes;
}
@ -288,8 +344,22 @@ bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph)
return changed;
}
void EliminateGetItem(const FuncGraphPtr &func_graph) {
std::shared_ptr<Pass> eliminate_getitem_pass = std::make_shared<opt::GetitemTuple>();
auto todos = TopoSort(func_graph->get_return());
for (auto node : todos) {
if (AnfAlgo::IsGraphKernel(node)) {
eliminate_getitem_pass->Run(AnfAlgo::GetCNodeFuncGraphPtr(node));
}
}
}
bool CompositeOpsFusion::Run(const FuncGraphPtr &func_graph) {
return FuseCompositeOps(std::dynamic_pointer_cast<session::KernelGraph>(func_graph));
auto changed = FuseCompositeOps(std::dynamic_pointer_cast<session::KernelGraph>(func_graph));
if (changed) {
EliminateGetItem(func_graph);
}
return changed;
}
} // namespace opt
} // namespace mindspore