From b9b4a5e5f74d93d52803d20af475e5fc384f4950 Mon Sep 17 00:00:00 2001 From: dayschan Date: Fri, 8 Jan 2021 10:48:58 +0800 Subject: [PATCH] Add a restriction for getitem in basic_ops_fusion. this commit reverts the modification for basic_ops_fusion.cc in 8af78cd5c, the getitem should be fused with its all users. (no bug. but when the network is large, it works very slowly, this's a temporary solution) --- .../optimizer/graph_kernel/basic_ops_fusion.cc | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc index 4bf455e0fe9..f1057685d62 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc @@ -66,23 +66,36 @@ IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNode return EXCLUDE; } -// The GetItem node should be fused with its real input. +// The GetItem node should be fused with its real input and users. // If its real input is not in the fuse_list, the GetItem should be excluded. AnfNodePtrList RemoveWildGetitem(const AnfNodePtrList &fused_op) { if (fused_op.empty()) return AnfNodePtrList(); std::set fused_op_set(fused_op.begin(), fused_op.end()); auto check_include = [&fused_op_set](const AnfNodePtr &node) { return fused_op_set.count(node) ? FOLLOW : EXCLUDE; }; + auto mng = fused_op[0]->func_graph()->manager(); + MS_EXCEPTION_IF_NULL(mng); bool changed = true; while (changed) { changed = false; AnfNodePtrList remove_list; for (auto getitem : fused_op_set) { if (!AnfAlgo::CheckPrimitiveType(getitem, prim::kPrimTupleGetItem)) continue; + // GetItem should be fused with its real input. auto prev_node = getitem->cast()->input(kRealInputNodeIndexInTupleGetItem); if (check_include(prev_node) == EXCLUDE) { remove_list.push_back(getitem); + break; + } + + // GetItem should be fused with its all users. + const auto &users = mng->node_users()[getitem]; + if (std::any_of(users.begin(), users.end(), [check_include](const std::pair &user) { + return check_include(user.first) == EXCLUDE; + })) { + remove_list = DeepLinkedGraphSearch(getitem, check_include); + break; } } if (!remove_list.empty()) {