forked from mindspore-Ecosystem/mindspore
!11093 【GraphKernel】Add a restriction for getitem in basic_ops_fusion
From: @dayschan Reviewed-by: @gaoxiong1,@ckey_dou Signed-off-by: @ckey_dou
This commit is contained in:
commit
5b751cfa4a
|
@ -66,23 +66,36 @@ IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNode
|
|||
return EXCLUDE;
|
||||
}
|
||||
|
||||
// The GetItem node should be fused with its real input.
|
||||
// The GetItem node should be fused with its real input and users.
|
||||
// If its real input is not in the fuse_list, the GetItem should be excluded.
|
||||
AnfNodePtrList RemoveWildGetitem(const AnfNodePtrList &fused_op) {
|
||||
if (fused_op.empty()) return AnfNodePtrList();
|
||||
std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end());
|
||||
auto check_include = [&fused_op_set](const AnfNodePtr &node) { return fused_op_set.count(node) ? FOLLOW : EXCLUDE; };
|
||||
|
||||
auto mng = fused_op[0]->func_graph()->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
bool changed = true;
|
||||
while (changed) {
|
||||
changed = false;
|
||||
AnfNodePtrList remove_list;
|
||||
for (auto getitem : fused_op_set) {
|
||||
if (!AnfAlgo::CheckPrimitiveType(getitem, prim::kPrimTupleGetItem)) continue;
|
||||
|
||||
// GetItem should be fused with its real input.
|
||||
auto prev_node = getitem->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
|
||||
if (check_include(prev_node) == EXCLUDE) {
|
||||
remove_list.push_back(getitem);
|
||||
break;
|
||||
}
|
||||
|
||||
// GetItem should be fused with its all users.
|
||||
const auto &users = mng->node_users()[getitem];
|
||||
if (std::any_of(users.begin(), users.end(), [check_include](const std::pair<AnfNodePtr, int> &user) {
|
||||
return check_include(user.first) == EXCLUDE;
|
||||
})) {
|
||||
remove_list = DeepLinkedGraphSearch(getitem, check_include);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!remove_list.empty()) {
|
||||
|
|
Loading…
Reference in New Issue