forked from mindspore-Ecosystem/mindspore
Fuse composite ops separated by GetItem nodes
This commit is contained in:
parent
e2e532dec3
commit
8e6d92eac9
|
@ -36,6 +36,7 @@
|
|||
#include "debug/anf_ir_dump.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||
#include "backend/optimizer/pass/getitem_tuple.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -60,20 +61,20 @@ std::vector<AnfNodePtr> DeepUsersSearch(const std::vector<AnfNodePtr> &roots, co
|
|||
}
|
||||
} // namespace
|
||||
|
||||
bool IsFuse(const AnfNodePtr &node) {
|
||||
// composite fuse composite op
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
return true;
|
||||
}
|
||||
return IsBasicFuseOp(node);
|
||||
}
|
||||
|
||||
IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNodePtr &node) {
|
||||
if (cur_node == node) {
|
||||
return FOLLOW;
|
||||
}
|
||||
bool is_fusable = IsFuse(node);
|
||||
return is_fusable ? FOLLOW : EXCLUDE;
|
||||
if (AnfAlgo::IsGraphKernel(node) || IsBasicFuseOp(node)) {
|
||||
return FOLLOW;
|
||||
}
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
|
||||
auto prev_node = node->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
|
||||
if (AnfAlgo::IsGraphKernel(prev_node)) {
|
||||
return FOLLOW;
|
||||
}
|
||||
}
|
||||
return EXCLUDE;
|
||||
}
|
||||
|
||||
IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, const AnfNodePtr &node) {
|
||||
|
@ -185,6 +186,60 @@ std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bo
|
|||
return res;
|
||||
}
|
||||
|
||||
// The GetItem node should be fused with its real input and users.
|
||||
// If its real input is not in the fuse_list, the GetItem should be excluded.
|
||||
AnfNodePtrList RemoveWildGetitem(const AnfNodePtrList &fused_op) {
|
||||
if (fused_op.empty()) return AnfNodePtrList();
|
||||
std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end());
|
||||
auto check_include = [&fused_op_set](const AnfNodePtr &node) { return fused_op_set.count(node) ? FOLLOW : EXCLUDE; };
|
||||
|
||||
auto mng = fused_op[0]->func_graph()->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
bool changed = true;
|
||||
while (changed) {
|
||||
changed = false;
|
||||
AnfNodePtrList remove_list;
|
||||
for (auto node : fused_op_set) {
|
||||
if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) continue;
|
||||
// GetItem should be fused with its real input.
|
||||
auto prev_node = node->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
|
||||
if (check_include(prev_node) == EXCLUDE) {
|
||||
remove_list.push_back(node);
|
||||
break;
|
||||
}
|
||||
|
||||
// GetItem should be fused with its all users.
|
||||
auto &users = mng->node_users()[node];
|
||||
bool outside_user_found = false;
|
||||
for (auto iter = users.begin(); iter != users.end(); ++iter) {
|
||||
if (check_include(iter->first) == EXCLUDE) {
|
||||
outside_user_found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (outside_user_found) {
|
||||
remove_list = DeepUsersSearch(node, check_include, mng);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!remove_list.empty()) {
|
||||
for (auto node : remove_list) {
|
||||
fused_op_set.erase(node);
|
||||
}
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
// keep the original order of fused_op.
|
||||
AnfNodePtrList result;
|
||||
for (auto node : fused_op) {
|
||||
if (fused_op_set.count(node)) {
|
||||
result.push_back(node);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void TopoSortForNodeList(std::vector<AnfNodePtr> *lst) {
|
||||
if (lst->size() < 2) {
|
||||
return;
|
||||
|
@ -254,6 +309,7 @@ std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode) {
|
|||
if (used_nodes.size() > 1) {
|
||||
used_nodes = RemoveCircle(used_nodes);
|
||||
}
|
||||
used_nodes = RemoveWildGetitem(used_nodes);
|
||||
TopoSortForNodeList(&used_nodes);
|
||||
return used_nodes;
|
||||
}
|
||||
|
@ -288,8 +344,22 @@ bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph)
|
|||
return changed;
|
||||
}
|
||||
|
||||
void EliminateGetItem(const FuncGraphPtr &func_graph) {
|
||||
std::shared_ptr<Pass> eliminate_getitem_pass = std::make_shared<opt::GetitemTuple>();
|
||||
auto todos = TopoSort(func_graph->get_return());
|
||||
for (auto node : todos) {
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
eliminate_getitem_pass->Run(AnfAlgo::GetCNodeFuncGraphPtr(node));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool CompositeOpsFusion::Run(const FuncGraphPtr &func_graph) {
|
||||
return FuseCompositeOps(std::dynamic_pointer_cast<session::KernelGraph>(func_graph));
|
||||
auto changed = FuseCompositeOps(std::dynamic_pointer_cast<session::KernelGraph>(func_graph));
|
||||
if (changed) {
|
||||
EliminateGetItem(func_graph);
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue