diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h index 3092944e40d..b6b04777ea3 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/incorporate_getitem.h @@ -376,7 +376,10 @@ class IncorporateGetitemSwitch : public AnfVisitor { } auto tuple_getitem = node->cast(); MS_EXCEPTION_IF_NULL(tuple_getitem); - if (MultipleUseOfSwitch(tuple_getitem->input(1), fg) && !ExistEnvNode(fg)) { + // If exist env_getitem/env_setitem in this funcgraph or + // if g1_/g2_ is fprop func_graph and the corresponding bprop funcgraph has any env_getitem or env_setitem; + if (MultipleUseOfSwitch(tuple_getitem->input(1), fg) && !ExistEnvNode(fg) && !ExistEnvNodeInTupleItem(g1_) && + !ExistEnvNodeInTupleItem(g2_)) { return nullptr; } auto new_g1 = getitem_transform_(g1_, idx_); @@ -455,6 +458,23 @@ class IncorporateGetitemSwitch : public AnfVisitor { }); } + static bool inline ExistEnvNodeInTupleItem(const FuncGraphPtr &fg) { + MS_EXCEPTION_IF_NULL(fg); + const auto &output = fg->output(); + if (!IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { + return false; + } + const auto &cnode = output->cast(); + const auto &inputs = cnode->inputs(); + return std::any_of(inputs.cbegin() + 1, inputs.cend(), [](const auto &node) { + auto sub_fg = GetValueNode(node); + if (sub_fg != nullptr && ExistEnvNode(sub_fg)) { + return true; + } + return false; + }); + } + int64_t idx_{-1}; AnfNodePtr switch_{nullptr}, x_{nullptr}; FuncGraphPtr g1_{nullptr}, g2_{nullptr};