From 8e6d92eac904ebb374da1e7307e48397194d4f81 Mon Sep 17 00:00:00 2001 From: dayschan Date: Fri, 13 Nov 2020 11:55:11 +0800 Subject: [PATCH] Fuse composite ops separated by GetItem nodes --- .../graph_kernel/composite_ops_fusion.cc | 92 ++++++++++++++++--- 1 file changed, 81 insertions(+), 11 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.cc index 2914dd09c8b..1dc017ba1c1 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/composite_ops_fusion.cc @@ -36,6 +36,7 @@ #include "debug/anf_ir_dump.h" #include "ir/func_graph_cloner.h" #include "backend/optimizer/graph_kernel/graph_kernel_helper.h" +#include "backend/optimizer/pass/getitem_tuple.h" namespace mindspore { namespace opt { @@ -60,20 +61,20 @@ std::vector DeepUsersSearch(const std::vector &roots, co } } // namespace -bool IsFuse(const AnfNodePtr &node) { - // composite fuse composite op - if (AnfAlgo::IsGraphKernel(node)) { - return true; - } - return IsBasicFuseOp(node); -} - IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNodePtr &node) { if (cur_node == node) { return FOLLOW; } - bool is_fusable = IsFuse(node); - return is_fusable ? FOLLOW : EXCLUDE; + if (AnfAlgo::IsGraphKernel(node) || IsBasicFuseOp(node)) { + return FOLLOW; + } + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { + auto prev_node = node->cast()->input(kRealInputNodeIndexInTupleGetItem); + if (AnfAlgo::IsGraphKernel(prev_node)) { + return FOLLOW; + } + } + return EXCLUDE; } IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, const AnfNodePtr &node) { @@ -185,6 +186,60 @@ std::vector RemoveCircle(const std::vector &fused_op, bo return res; } +// The GetItem node should be fused with its real input and users. +// If its real input is not in the fuse_list, the GetItem should be excluded. +AnfNodePtrList RemoveWildGetitem(const AnfNodePtrList &fused_op) { + if (fused_op.empty()) return AnfNodePtrList(); + std::set fused_op_set(fused_op.begin(), fused_op.end()); + auto check_include = [&fused_op_set](const AnfNodePtr &node) { return fused_op_set.count(node) ? FOLLOW : EXCLUDE; }; + + auto mng = fused_op[0]->func_graph()->manager(); + MS_EXCEPTION_IF_NULL(mng); + bool changed = true; + while (changed) { + changed = false; + AnfNodePtrList remove_list; + for (auto node : fused_op_set) { + if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) continue; + // GetItem should be fused with its real input. + auto prev_node = node->cast()->input(kRealInputNodeIndexInTupleGetItem); + if (check_include(prev_node) == EXCLUDE) { + remove_list.push_back(node); + break; + } + + // GetItem should be fused with its all users. + auto &users = mng->node_users()[node]; + bool outside_user_found = false; + for (auto iter = users.begin(); iter != users.end(); ++iter) { + if (check_include(iter->first) == EXCLUDE) { + outside_user_found = true; + break; + } + } + if (outside_user_found) { + remove_list = DeepUsersSearch(node, check_include, mng); + break; + } + } + if (!remove_list.empty()) { + for (auto node : remove_list) { + fused_op_set.erase(node); + } + changed = true; + } + } + + // keep the original order of fused_op. + AnfNodePtrList result; + for (auto node : fused_op) { + if (fused_op_set.count(node)) { + result.push_back(node); + } + } + return result; +} + void TopoSortForNodeList(std::vector *lst) { if (lst->size() < 2) { return; @@ -254,6 +309,7 @@ std::vector FindFuseCNodes(const CNodePtr &cnode) { if (used_nodes.size() > 1) { used_nodes = RemoveCircle(used_nodes); } + used_nodes = RemoveWildGetitem(used_nodes); TopoSortForNodeList(&used_nodes); return used_nodes; } @@ -288,8 +344,22 @@ bool FuseCompositeOps(const std::shared_ptr &kernel_graph) return changed; } +void EliminateGetItem(const FuncGraphPtr &func_graph) { + std::shared_ptr eliminate_getitem_pass = std::make_shared(); + auto todos = TopoSort(func_graph->get_return()); + for (auto node : todos) { + if (AnfAlgo::IsGraphKernel(node)) { + eliminate_getitem_pass->Run(AnfAlgo::GetCNodeFuncGraphPtr(node)); + } + } +} + bool CompositeOpsFusion::Run(const FuncGraphPtr &func_graph) { - return FuseCompositeOps(std::dynamic_pointer_cast(func_graph)); + auto changed = FuseCompositeOps(std::dynamic_pointer_cast(func_graph)); + if (changed) { + EliminateGetItem(func_graph); + } + return changed; } } // namespace opt } // namespace mindspore