diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.cc
index 8e98f25cb1b..ca6b0a1bc03 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.cc
@@ -122,7 +122,7 @@ const AnfNodePtr BatchNorm2BNInfer::Process(const FuncGraphPtr &graph, const Anf
     return nullptr;
   }
   auto bn_infer = CreateBNInfer(graph, batchnorm, node);
-  TransferDepend(batchnorm, graph, bn_infer);
+  TransferDependOrUpdateState(batchnorm, graph, bn_infer);
   return bn_infer;
 }
 }  // namespace opt
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc
index 117c4217c93..2a88d6fce1c 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc
@@ -125,7 +125,7 @@ const AnfNodePtr BatchNormGrad2BNInferGrad::Process(const FuncGraphPtr &graph, c
     return nullptr;
   }
   auto bn_infer_grad = CreateBNInferGrad(graph, batchnorm_grad, node);
-  TransferDepend(batchnorm_grad, graph, bn_infer_grad);
+  TransferDependOrUpdateState(batchnorm_grad, graph, bn_infer_grad);
   return bn_infer_grad;
 }
 }  // namespace opt
diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.cc b/mindspore/ccsrc/backend/optimizer/common/helper.cc
index a59499da83d..a07f9e023b1 100644
--- a/mindspore/ccsrc/backend/optimizer/common/helper.cc
+++ b/mindspore/ccsrc/backend/optimizer/common/helper.cc
@@ -916,21 +916,34 @@ ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) {
   return new_value_node;
 }
 
-void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) {
+void TransferDependOrUpdateState(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node) {
   MS_EXCEPTION_IF_NULL(old_node);
   MS_EXCEPTION_IF_NULL(graph);
   auto manager = graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
   // Find BatchNorm's output which is a Depend or UpdateState.
-  for (const auto &node_index : manager->node_users()[old_node]) {
+  auto node_users = manager->node_users()[old_node];
+  for (const auto &node_index : node_users) {
     AnfNodePtr output = node_index.first;
-    size_t index = IntToSize(node_index.second);
     MS_EXCEPTION_IF_NULL(output);
     if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend) ||
         AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) {
-      auto depend = output->cast<CNodePtr>();
-      MS_EXCEPTION_IF_NULL(depend);
-      depend->set_input(index, new_node);
+      auto output_cnode = output->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(output_cnode);
+      auto inputs = output_cnode->inputs();
+      std::vector<AnfNodePtr> new_inputs{output_cnode->input(0)};
+      for (size_t i = 1; i < inputs.size(); i++) {
+        auto input = inputs[i];
+        if (input == old_node) {
+          new_inputs.emplace_back(new_node);
+        } else {
+          new_inputs.emplace_back(input);
+        }
+      }
+      auto new_output = graph->NewCNode(new_inputs);
+      new_output->set_abstract(output->abstract());
+      new_output->set_scope(output->scope());
+      manager->Replace(output, new_output);
     }
   }
 }
diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.h b/mindspore/ccsrc/backend/optimizer/common/helper.h
index 88537b50d8c..e298b4c1192 100644
--- a/mindspore/ccsrc/backend/optimizer/common/helper.h
+++ b/mindspore/ccsrc/backend/optimizer/common/helper.h
@@ -213,8 +213,8 @@ bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &suppor
 // Create a new value node of func graph,not kernel graph
 ValueNodePtr MakeValueNode(const ValueNodePtr &value_node);
 
-// Transfer depend to the new node
-void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node);
+// Transfer depend or updatestate to the new node
+void TransferDependOrUpdateState(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node);
 
 AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list);
 
diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc
index 41f65335f56..53f175447d2 100644
--- a/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/order_enforce.cc
@@ -76,15 +76,17 @@
     }
   }
 
-  bool CheckMakeTupleHaveLoad(const CNodePtr &cnode) {
+  std::unordered_set<AnfNodePtr> CheckMakeTupleHaveLoad(const CNodePtr &cnode) {
+    MS_EXCEPTION_IF_NULL(cnode);
+    std::unordered_set<AnfNodePtr> loads;
     auto inputs = cnode->inputs();
     for (size_t index = 1; index < inputs.size(); index++) {
       auto input = cnode->input(index);
       if (IsPrimitiveCNode(input, prim::kPrimLoad)) {
-        return true;
+        loads.insert(input);
       }
     }
-    return false;
+    return loads;
   }
 
   std::vector<AnfNodePtr> FindUpdateStateUsers(const CNodePtr &cnode) {
@@ -155,23 +157,31 @@
   // u3 = UpdateState(u', maketuple2, addn) # need put addn or other-op into u3 inputs
   // assign = Assign(para2, inputs, u3)
   void HandleMakeTupleUsers(const AnfNodePtr &node) {
+    MS_EXCEPTION_IF_NULL(node);
     auto maketuple = node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(maketuple);
-    if (CheckMakeTupleHaveLoad(maketuple)) {
+    std::unordered_set<AnfNodePtr> loads = CheckMakeTupleHaveLoad(maketuple);
+    if (!loads.empty()) {
       auto update_state = FindLastUpdateState(maketuple);
       if (update_state != nullptr) {
         std::unordered_set<AnfNodePtr> maketuple_users = GetSpecialOperatorRealUsers(maketuple);
-        std::unordered_set<AnfNodePtr> no_push_maketuple_users;
+        std::unordered_set<AnfNodePtr> no_push_all_users;
         // Push and Pull at the end of the execution order,
         // In order to ensure push and pull operator cut into the same graph, do not put push operator into updatestate
         for (auto maketuple_user : maketuple_users) {
           if (!IsPrimitiveCNode(maketuple_user, prim::kPrimPush)) {
-            no_push_maketuple_users.insert(maketuple_user);
+            no_push_all_users.insert(maketuple_user);
+          }
+        }
+        for (auto load : loads) {
+          std::unordered_set<AnfNodePtr> load_users = GetSpecialOperatorRealUsers(load);
+          for (auto load_user : load_users) {
+            no_push_all_users.insert(load_user);
           }
         }
         auto update_state_cnode = update_state->cast<CNodePtr>();
         MS_EXCEPTION_IF_NULL(update_state_cnode);
-        AddInputEdges(update_state_cnode, no_push_maketuple_users);
+        AddInputEdges(update_state_cnode, no_push_all_users);
       }
     }
   }
@@ -265,6 +275,8 @@
   // Add load users as input edges of the update_state node.
   void AddInputEdges(const CNodePtr &update_state, const std::unordered_set<AnfNodePtr> &load_users) {
     auto sorted_load_users = SortLoadUsers(load_users);
+    auto inputs = update_state->inputs();
+    size_t origin_size = inputs.size();
     for (auto &load_user : sorted_load_users) {
       if (IsPrimitiveCNode(load_user, prim::kPrimMakeTuple) || IsPrimitiveCNode(load_user, prim::kPrimUpdateState)) {
         continue;
@@ -272,10 +284,16 @@
       if (!IsDependOn(load_user, update_state)) {
         processed_nodes_.insert(load_user);
         if (!IsInUpdateState(load_user, update_state)) {
-          manager_->AddEdge(update_state, load_user);
+          inputs.emplace_back(load_user);
        }
      }
    }
+    if (inputs.size() > origin_size) {
+      auto new_update_state = func_graph_->NewCNode(inputs);
+      new_update_state->set_abstract(update_state->abstract());
+      new_update_state->set_scope(update_state->scope());
+      manager_->Replace(update_state, new_update_state);
+    }
   }
 
   // Sort load users by their topo sort order.
diff --git a/tests/st/control/inner/test_110_if_after_if_in_if.py b/tests/st/control/inner/test_110_if_after_if_in_if.py
index 6d3ed59a1e4..a0e3ad893ad 100644
--- a/tests/st/control/inner/test_110_if_after_if_in_if.py
+++ b/tests/st/control/inner/test_110_if_after_if_in_if.py
@@ -157,7 +157,9 @@ def test_if_after_if_in_if():
     control_flow_if_after_if_in_if(IfAfterIfInIfNet, x)
 
 
-@pytest.mark.skip(reason="not supported side effect")
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
def test_if_after_if_in_if_01():
     x = Tensor(2, mstype.int32)
     control_flow_if_after_if_in_if(IfAfterIfInIfNet1, x)