diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/inline.h b/mindspore/ccsrc/frontend/optimizer/irpass/inline.h
index 2f4a8628224..cdfd02f7a7a 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass/inline.h
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/inline.h
@@ -30,6 +30,8 @@
 #include "ir/func_graph_cloner.h"
 #include "ir/tensor.h"
 #include "frontend/operator/ops.h"
+#include "abstract/abstract_value.h"
+#include "utils/utils.h"
 
 namespace mindspore {
 namespace opt {
@@ -170,6 +172,7 @@ class InlinerBase : public AnfVisitor {
       }
     }
     // Or, just make a clone for not single used fg.
+    MS_LOG(INFO) << "Run InlineClone in inline pass, subgraph number may increase.";
     return InlineClone(fg, node->func_graph(), args, inputs[0]->scope());
   }
 
@@ -279,6 +282,30 @@ class InlinerBase : public AnfVisitor {
     return node->func_graph()->NewCNode(node_inputs);
   }
 
+  bool CheckSwitchBranchAbstract(const AbstractBasePtr &branch_abstract) {
+    if (branch_abstract != nullptr && branch_abstract->isa<abstract::AbstractError>()) {
+      auto branch_abstract_value = branch_abstract->GetValueTrack();
+      MS_EXCEPTION_IF_NULL(branch_abstract_value);
+      auto branch_abstract_value_string_imm = branch_abstract_value->cast<StringImmPtr>();
+      if (branch_abstract_value_string_imm != nullptr) {
+        auto branch_abstract_value_string_imm_value = branch_abstract_value_string_imm->value();
+        return branch_abstract_value_string_imm_value == kDeadNodeName ||
+               branch_abstract_value_string_imm_value == kPolyNodeName;
+      }
+    }
+    return false;
+  }
+
+  bool CheckSwitchInputs(const std::vector<AnfNodePtr> &sw_inputs) {
+    auto true_branch_abstract = sw_inputs[kSwitchTrueKernelGraphIndex]->abstract();
+    auto false_branch_abstract = sw_inputs[kSwitchFalseKernelGraphIndex]->abstract();
+    // When branch has dead node or poly node, do not perform inline.
+    if (CheckSwitchBranchAbstract(true_branch_abstract) || CheckSwitchBranchAbstract(false_branch_abstract)) {
+      return true;
+    }
+    return !sw_inputs[1]->isa<ValueNode>() || IsValueNode<tensor::Tensor>(sw_inputs[1]);
+  }
+
   // This is a try-best algorithm to find a graph which may generate branch call.
   // It does not handle high-order function call. For high-orderer call branch, it still may be inlined.
   bool GraphHasBranch(FuncGraphPtr fg) {
@@ -293,7 +320,7 @@
           if (sw_inputs.size() != 4) {
             MS_LOG(EXCEPTION) << "switch inputs should be 4";
           }
-          if (!sw_inputs[1]->isa<ValueNode>() || IsValueNode<tensor::Tensor>(sw_inputs[1])) {
+          if (CheckSwitchInputs(sw_inputs)) {
            has_branch = true;
            break;
          }
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/parameter_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/parameter_eliminate.h
index 55d9a22e0eb..a4717c0ee74 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass/parameter_eliminate.h
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/parameter_eliminate.h
@@ -41,18 +41,18 @@
     MS_EXCEPTION_IF_NULL(manager);
     bool changes = false;
     while (true) {
-      auto tr = manager->Transact();
       const auto &[fg, callers] = SearchFuncGraphCallers(func_graph);
       if (fg == nullptr) {
         break;
       }
-      const auto &erase_indexes = EraseUnusedParameters(fg, &tr);
+      auto manager = fg->manager();
+      MS_EXCEPTION_IF_NULL(manager);
+      const auto &erase_indexes = EraseUnusedParameters(fg, manager);
       for (auto caller : callers) {
         // Erase the corresponding args.
-        EraseArgs(caller, erase_indexes, &tr);
+        EraseArgs(caller, erase_indexes, manager);
       }
       changes = true;
-      tr.Commit();
     }
     return changes;
   }
@@ -99,7 +99,7 @@
     return {nullptr, {}};
   }
 
-  static std::unordered_set<size_t> EraseUnusedParameters(const FuncGraphPtr &fg, FuncGraphTransaction *tr) {
+  static std::unordered_set<size_t> EraseUnusedParameters(const FuncGraphPtr &fg, const FuncGraphManagerPtr &manager) {
     MS_EXCEPTION_IF_NULL(fg->manager());
     const auto &manager_node_users = fg->manager()->node_users();
     const auto &parameters = fg->parameters();
@@ -122,12 +122,12 @@
         MS_LOG(DEBUG) << "Erase parameter:" << parameters[i]->DebugString() << ",index:" << i;
       }
     }
-    tr->SetParameters(fg, new_parameters);
+    manager->SetParameters(fg, new_parameters);
     return unused_parameter_indexes;
   }
 
   static void EraseArgs(const CNodePtr &caller, const std::unordered_set<size_t> &unused_parameter_indexes,
-                        FuncGraphTransaction *tr) {
+                        const FuncGraphManagerPtr &manager) {
     std::vector<AnfNodePtr> new_args = {caller->inputs()[0]};
     for (size_t i = 0; i < caller->inputs().size() - 1; i++) {
       if (unused_parameter_indexes.find(i) == unused_parameter_indexes.end()) {
@@ -139,7 +139,7 @@
     TraceGuard trace_guard(std::make_shared<TraceCopy>(caller->debug_info()));
     auto new_caller = caller->func_graph()->NewCNode(new_args);
     new_caller->set_abstract(caller->abstract());
-    tr->Replace(caller, new_caller);
+    manager->Replace(caller, new_caller);
   }
 };
 }  // namespace irpass
diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc
index 6e32b466e46..42b7bbe8b2d 100644
--- a/mindspore/ccsrc/pipeline/jit/pass.cc
+++ b/mindspore/ccsrc/pipeline/jit/pass.cc
@@ -334,7 +334,8 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
   opt::OptPassConfig updatestate_loads_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateLoadsEliminater());
 
   // Before adjusting map_a, check GetA1A2() and GetOptPynativeGradEpiloguePhases().
-  OptPassGroupMap map_a({{"a_1", a_1},
+  OptPassGroupMap map_a({{"switch_simplify", opt::OptPassConfig({irpass.switch_simplify_})},
+                         {"a_1", a_1},
                          {"updatestate_depend_eliminate", updatestate_depend_eliminate},
                          {"updatestate_assign_eliminate", updatestate_assign_eliminate},
                          {"updatestate_loads_eliminate", updatestate_loads_eliminate},
@@ -360,7 +361,7 @@
 OptPassGroupMap GetA1A2(const opt::irpass::OptimizeIRPassLib &irpass) {
   auto opt_a = GetOptPassesA(irpass);
-  constexpr auto a1_a2_len = 6;
+  constexpr auto a1_a2_len = 7;
   OptPassGroupMap a1_a2(opt_a.begin(), opt_a.begin() + a1_a2_len);
   return a1_a2;
 }
diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc
index abe4755ce8e..feb91e77fa3 100644
--- a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc
@@ -24,6 +24,7 @@
 #include "frontend/operator/composite/do_signature.h"
 #include "abstract/abstract_function.h"
 #include "abstract/utils.h"
+#include "utils/utils.h"
 #include "ir/graph_utils.h"
 #include "utils/log_adapter.h"
 #include "debug/trace.h"
@@ -396,8 +397,8 @@ AnfNodePtr FuncGraphSpecializer::BuildReplacedNode(const AnfNodeConfigPtr &conf)
 }
 
 namespace {
-const StringImmPtr kDeadNode = std::make_shared<StringImm>("Dead Node");
-const StringImmPtr kPolyNode = std::make_shared<StringImm>("Poly Node");
+const StringImmPtr kDeadNode = std::make_shared<StringImm>(kDeadNodeName);
+const StringImmPtr kPolyNode = std::make_shared<StringImm>(kPolyNodeName);
 
 inline bool CanSpecializeNode(const AnfNodePtr &node) {
   if (IsValueNode<FuncGraph>(node) || IsValueNode<MetaFuncGraph>(node) || IsValueNode<Primitive>(node)) {
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index e157f4c0f51..20333448cae 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -213,6 +213,8 @@ constexpr auto kFusedCastAdamWeightDecayName = "FusedCastAdamWeightDecay";
 constexpr auto kFusedAdamName = "FusedAdam";
 constexpr auto kFusedSparseAdamName = "FusedSparseAdam";
 constexpr auto kFusedMatMulBiasAddName = "FusedMatMulBiasAdd";
+constexpr auto kDeadNodeName = "DeadNode";
+constexpr auto kPolyNodeName = "PolyNode";
 constexpr auto kApplyAdagradV2OpName = "ApplyAdagradV2";
 constexpr auto kSparseApplyAdagradV2OpName = "SparseApplyAdagradV2";
 constexpr auto kSparseApplyFtrlOpName = "SparseApplyFtrl";
diff --git a/tests/st/control/inner/test_112_if_after_if_in_for.py b/tests/st/control/inner/test_112_if_after_if_in_for.py
index 70a8e49a063..9cd6be5dce8 100644
--- a/tests/st/control/inner/test_112_if_after_if_in_for.py
+++ b/tests/st/control/inner/test_112_if_after_if_in_for.py
@@ -133,8 +133,8 @@ def control_flow_if_after_if_in_for(input_net, x, expect1, expect2):
     assert graph_backward_res == expect2
 
 
-@pytest.mark.skip(reason="ME EvalCNode error")
-@pytest.mark.level1
+# @pytest.mark.skip(reason="ME EvalCNode error")
+@pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training