!11093 【GraphKernel】Add a restriction for getitem in basic_ops_fusion

From: @dayschan
Reviewed-by: @gaoxiong1,@ckey_dou
Signed-off-by: @ckey_dou
This commit is contained in:
mindspore-ci-bot 2021-01-11 16:57:59 +08:00 committed by Gitee
commit 5b751cfa4a
1 changed files with 14 additions and 1 deletions

View File

@ -66,23 +66,36 @@ IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNode
return EXCLUDE;
}
// The GetItem node should be fused with its real input.
// The GetItem node should be fused with its real input and users.
// If its real input is not in the fuse_list, the GetItem should be excluded.
AnfNodePtrList RemoveWildGetitem(const AnfNodePtrList &fused_op) {
if (fused_op.empty()) return AnfNodePtrList();
std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end());
auto check_include = [&fused_op_set](const AnfNodePtr &node) { return fused_op_set.count(node) ? FOLLOW : EXCLUDE; };
auto mng = fused_op[0]->func_graph()->manager();
MS_EXCEPTION_IF_NULL(mng);
bool changed = true;
while (changed) {
changed = false;
AnfNodePtrList remove_list;
for (auto getitem : fused_op_set) {
if (!AnfAlgo::CheckPrimitiveType(getitem, prim::kPrimTupleGetItem)) continue;
// GetItem should be fused with its real input.
auto prev_node = getitem->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
if (check_include(prev_node) == EXCLUDE) {
remove_list.push_back(getitem);
break;
}
// GetItem should be fused with its all users.
const auto &users = mng->node_users()[getitem];
if (std::any_of(users.begin(), users.end(), [check_include](const std::pair<AnfNodePtr, int> &user) {
return check_include(user.first) == EXCLUDE;
})) {
remove_list = DeepLinkedGraphSearch(getitem, check_include);
break;
}
}
if (!remove_list.empty()) {