!14113 incorporate tuple_getitem(switch()) if g1 or g2 is a specific kind of fprop funcgraph

From: @xychow
Reviewed-by: @zh_qh,@ginfung
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-03-26 09:19:12 +08:00 committed by Gitee
commit b5a25045f8
1 changed files with 21 additions and 1 deletions

View File

@ -376,7 +376,10 @@ class IncorporateGetitemSwitch : public AnfVisitor {
}
auto tuple_getitem = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_getitem);
if (MultipleUseOfSwitch(tuple_getitem->input(1), fg) && !ExistEnvNode(fg)) {
// If exist env_getitem/env_setitem in this funcgraph or
// if g1_/g2_ is fprop func_graph and the corresponding bprop funcgraph has any env_getitem or env_setitem;
if (MultipleUseOfSwitch(tuple_getitem->input(1), fg) && !ExistEnvNode(fg) && !ExistEnvNodeInTupleItem(g1_) &&
!ExistEnvNodeInTupleItem(g2_)) {
return nullptr;
}
auto new_g1 = getitem_transform_(g1_, idx_);
@ -455,6 +458,23 @@ class IncorporateGetitemSwitch : public AnfVisitor {
});
}
static bool inline ExistEnvNodeInTupleItem(const FuncGraphPtr &fg) {
MS_EXCEPTION_IF_NULL(fg);
const auto &output = fg->output();
if (!IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
return false;
}
const auto &cnode = output->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
return std::any_of(inputs.cbegin() + 1, inputs.cend(), [](const auto &node) {
auto sub_fg = GetValueNode<FuncGraphPtr>(node);
if (sub_fg != nullptr && ExistEnvNode(sub_fg)) {
return true;
}
return false;
});
}
int64_t idx_{-1};
AnfNodePtr switch_{nullptr}, x_{nullptr};
FuncGraphPtr g1_{nullptr}, g2_{nullptr};