diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index db5b26de5ad..cb20d770533 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -70,9 +70,25 @@ OptimizeIRPassLib::OptimizeIRPassLib() { float_depend_g_call_ = MakeSubstitution(std::make_shared(), "float_depend_g_call", IsCNodeDup); // ops eliminate - item_tuple_or_list_eliminate_ = MakeSubstitution( - std::make_shared(), "item_tuple_or_list_eliminate", + tuple_list_get_item_eliminator_ = + MakeSubstitution(std::make_shared(), "tuple_list_get_item_eliminator", + {prim::kPrimTupleGetItem, prim::kPrimListGetItem}); + tuple_list_get_item_const_eliminator_ = + MakeSubstitution(std::make_shared(), "tuple_list_get_item_const_eliminator", + {prim::kPrimTupleGetItem, prim::kPrimListGetItem}); + tuple_list_set_item_eliminator_ = + MakeSubstitution(std::make_shared(), "tuple_list_set_item_eliminator", + {prim::kPrimTupleSetItem, prim::kPrimListSetItem}); + tuple_list_get_set_item_eliminator_ = + MakeSubstitution(std::make_shared(), "tuple_list_get_set_item_eliminator", + {prim::kPrimTupleGetItem, prim::kPrimListGetItem}); + tuple_list_get_item_depend_reorder_ = + MakeSubstitution(std::make_shared(), "tuple_list_get_item_depend_reorder", + {prim::kPrimTupleGetItem, prim::kPrimListGetItem}); + tuple_list_convert_item_index_to_positive_ = MakeSubstitution( + std::make_shared(), "tuple_list_convert_item_index_to_positive", {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem, prim::kPrimListSetItem}); + tile_eliminate_ = MakeSubstitution(std::make_shared(), "tile_eliminate", prim::kPrimTile); cast_eliminate_ = MakeSubstitution(std::make_shared(), "cast_eliminate", prim::kPrimCast); reshape_eliminate_ = MakeSubstitution(std::make_shared(), "reshape_eliminate", prim::kPrimReshape); @@ -99,7 +115,13 @@ OptimizeIRPassLib::OptimizeIRPassLib() { // Env Item Eliminate env_get_item_eliminate_ = MakeSubstitution(std::make_shared(), "env_get_item_eliminate", prim::kPrimEnvGetItem); - new_env_get_item_ = MakeSubstitution(std::make_shared(), "new_env_get_item", prim::kPrimEnvGetItem); + env_get_item_add_eliminate_ = + MakeSubstitution(std::make_shared(), "env_get_item_add_eliminate_", prim::kPrimEnvGetItem); + env_get_set_item_eliminate_ = + MakeSubstitution(std::make_shared(), "env_get_set_item_eliminate", prim::kPrimEnvGetItem); + env_get_item_depend_swap_ = + MakeSubstitution(std::make_shared(), "env_get_item_depend_swap", prim::kPrimEnvGetItem); + incorporate_env_getitem_bypass_recursive_ = MakeSubstitution(std::make_shared(true), "incorporate_env_get_item", prim::kPrimEnvGetItem); incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared(), diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h index 7be50cb9eb3..92d4c3a7bbb 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass.h @@ -39,7 +39,13 @@ class OptimizeIRPassLib { SubstitutionPtr adjust_all_reduce_mul_add_; SubstitutionPtr float_depend_g_call_; // ops eliminate - SubstitutionPtr item_tuple_or_list_eliminate_; + SubstitutionPtr tuple_list_get_item_eliminator_; + SubstitutionPtr tuple_list_get_item_const_eliminator_; + SubstitutionPtr tuple_list_set_item_eliminator_; + SubstitutionPtr tuple_list_get_set_item_eliminator_; + SubstitutionPtr tuple_list_get_item_depend_reorder_; + SubstitutionPtr tuple_list_convert_item_index_to_positive_; + SubstitutionPtr tile_eliminate_; SubstitutionPtr cast_eliminate_; SubstitutionPtr reshape_eliminate_; @@ -57,7 +63,9 @@ class OptimizeIRPassLib { // Env Item Eliminate SubstitutionPtr env_get_item_eliminate_; - SubstitutionPtr new_env_get_item_; + SubstitutionPtr env_get_item_add_eliminate_; + SubstitutionPtr env_get_set_item_eliminate_; + SubstitutionPtr env_get_item_depend_swap_; SubstitutionPtr incorporate_env_getitem_; SubstitutionPtr incorporate_env_getitem_bypass_recursive_; SubstitutionPtr incorporate_env_getitem_switch_; diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h index 1bdc56e4d00..6992005b6c6 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h @@ -157,7 +157,7 @@ class EnvGetitemTransformACrossGraph { } // namespace internal // {prim::kPrimEnvGetItem, C1, C2, Y} -> Y -class NewEnvGetItem : public AnfVisitor { +class EnvGetItemEliminater : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { PatternNode c1, c2, y; @@ -170,10 +170,10 @@ class NewEnvGetItem : public AnfVisitor { // {prim::kPrimEnvGetItem, {prim::kPrimEnvAdd, X, Y}, C, Z} -> // {prim::GetPythonOps("hyper_add"), {prim::kPrimEnvGetItem, X, C, Z}, {prim::kPrimEnvGetItem, Y, C, Z}} -class AddEnvGetItem : public AnfVisitor { +class EnvGetItemAddEliminater : public AnfVisitor { public: - AddEnvGetItem() : PrimHyperAdd_(prim::GetPythonOps("hyper_add")) {} - ~AddEnvGetItem() override = default; + EnvGetItemAddEliminater() : PrimHyperAdd_(prim::GetPythonOps("hyper_add")) {} + ~EnvGetItemAddEliminater() override = default; AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { is_match_ = false; @@ -211,7 +211,7 @@ class AddEnvGetItem : public AnfVisitor { }; // {prim::kPrimEnvGetItem, {prim::kPrimEnvSetItem, X, C1, Y}, C2, Z} -class EnvGetSetItem : public AnfVisitor { +class EnvGetSetItemEliminater : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { is_match_ = false; @@ -281,7 +281,7 @@ class EnvGetSetItem : public AnfVisitor { // {prim::kPrimEnvGetitem, {prim::kPrimDepend, X1, X2}, item, dflt} -> // {prim::kPrimDepend, {prim::kPrimEnvGetitem, X1, item, dflt}, X2} -class SwapEnvGetItemDepend : public OptimizerCaller { +class EnvGetItemDependSwap : public OptimizerCaller { public: AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { if (!node->isa() || node->func_graph() == nullptr) { @@ -297,36 +297,6 @@ class SwapEnvGetItemDepend : public OptimizerCaller { } }; -class EnvGetItemEliminater : public OptimizerCaller { - public: - EnvGetItemEliminater() - : new_env_get_item_(std::make_shared()), - add_env_get_item_(std::make_shared()), - env_get_set_item_(std::make_shared()), - swap_env_get_item_depend_(std::make_shared()) { - eliminaters_.emplace_back(new_env_get_item_); - eliminaters_.emplace_back(add_env_get_item_); - eliminaters_.emplace_back(env_get_set_item_); - eliminaters_.emplace_back(swap_env_get_item_depend_); - } - ~EnvGetItemEliminater() = default; - - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - AnfNodePtr new_node; - for (auto &eliminater : eliminaters_) { - new_node = (*eliminater)(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - } - return nullptr; - } - - private: - OptimizerCallerPtr new_env_get_item_, add_env_get_item_, env_get_set_item_, swap_env_get_item_depend_; - std::vector eliminaters_{}; -}; - // {prim::kPrimEnvGetItem, {G, Xs}, C, Y} class IncorporateEnvGetitem : public AnfVisitor { public: diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h index ad6657218d7..3f3e46be5cf 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/item_tuple_or_list_eliminate.h @@ -38,7 +38,7 @@ namespace irpass { // setitem([a, b, c, ...], -1, z) => setitem([a, b, c, ...], length - 1, z) // {prim::kPrimTupleSetItem, T, N, Z} // {prim::kPrimListSetItem, L, N, Z} -class ConvertItemIndexToPositive : public AnfVisitor { +class TupleListConvertItemIndexToPositive : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); @@ -96,7 +96,7 @@ class ConvertItemIndexToPositive : public AnfVisitor { // (a, b, c, ...)[1] => b // {prim::kPrimTupleGetItem, {prim::kPrimMakeTuple, Xs}, C} // {prim::kPrimListGetItem, {prim::kPrimMakeList, Xs}, C} -class GetitemEliminator : public AnfVisitor { +class TupleListGetitemEliminator : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); @@ -144,7 +144,7 @@ class GetitemEliminator : public AnfVisitor { // (a, b, c, ...)[1] => b // {prim::kPrimTupleGetItem, C1, C} // {prim::kPrimListGetItem, C1, C} -class GetitemConstEliminator : public AnfVisitor { +class TupleListGetitemConstEliminator : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); @@ -195,7 +195,7 @@ class GetitemConstEliminator : public AnfVisitor { // {prim::kPrimListSetItem, {prim::kPrimMakeList, a, b, c, ...}, 0, z} => {prim::kPrimMakeList, z, b, c, ...} // {prim::kPrimTupleSetItem, (a, b, c, ...), 0, z} => {prim::kPrimMakeTuple, z, b, c, ...} // {prim::kPrimListSetItem, [a, b, c, ...], 0, z} => {prim::kPrimMakeList, z, b, c, ...} -class SetitemEliminator : public AnfVisitor { +class TupleListSetitemEliminator : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); @@ -277,7 +277,7 @@ class SetitemEliminator : public AnfVisitor { // {prim::kPrimTupleGetItem, {prim::kPrimTupleSetItem, Y, C1, X}, C2} // {prim::kPrimListGetItem, {prim::kPrimListSetItem, Y, C1, X}, C2} -class GetSetitemEliminator : public AnfVisitor { +class TupleListGetSetitemEliminator : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); @@ -348,7 +348,7 @@ class GetSetitemEliminator : public AnfVisitor { // {prim::kPrimDepend, {prim::kPrimTupleGetItem, X, C}, Y} // {prim::kPrimListGetItem, {prim::kPrimDepend, X, Y}, C} -> // {prim::kPrimDepend, {prim::kPrimListGetItem, X, C}, Y} -class GetitemDependReorder : public AnfVisitor { +class TupleListGetitemDependReorder : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); @@ -405,41 +405,6 @@ class GetitemDependReorder : public AnfVisitor { AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr}; }; -class ItemTupleOrListEliminator : public OptimizerCaller { - public: - ItemTupleOrListEliminator() - : get_item_eliminator_(std::make_shared()), - get_item_const_eliminator_(std::make_shared()), - set_item_eliminator_(std::make_shared()), - get_set_item_eliminator_(std::make_shared()), - get_item_depend_reorder_(std::make_shared()), - convert_item_index_to_positive_(std::make_shared()) { - eliminators_.emplace_back(get_item_eliminator_); - eliminators_.emplace_back(get_item_const_eliminator_); - eliminators_.emplace_back(set_item_eliminator_); - eliminators_.emplace_back(get_set_item_eliminator_); - eliminators_.emplace_back(get_item_depend_reorder_); - eliminators_.emplace_back(convert_item_index_to_positive_); - } - ~ItemTupleOrListEliminator() = default; - - AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { - AnfNodePtr new_node; - for (auto &eliminator : eliminators_) { - new_node = (*eliminator)(optimizer, node); - if (new_node != nullptr) { - return new_node; - } - } - return nullptr; - } - - private: - OptimizerCallerPtr get_item_eliminator_, get_item_const_eliminator_, set_item_eliminator_, get_set_item_eliminator_, - get_item_depend_reorder_, convert_item_index_to_positive_; - std::vector eliminators_{}; -}; - } // namespace irpass } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index 7184cbd4717..96d1295d4c8 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -112,8 +112,18 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { irpass.replace_applicator_, // Miscellaneous - irpass.item_tuple_or_list_eliminate_, + irpass.tuple_list_get_item_eliminator_, + irpass.tuple_list_get_item_const_eliminator_, + irpass.tuple_list_set_item_eliminator_, + irpass.tuple_list_get_set_item_eliminator_, + irpass.tuple_list_get_item_depend_reorder_, + irpass.tuple_list_convert_item_index_to_positive_, + irpass.env_get_item_eliminate_, + irpass.env_get_item_add_eliminate_, + irpass.env_get_set_item_eliminate_, + irpass.env_get_item_depend_swap_, + irpass.cast_eliminate_, irpass.reshape_eliminate_, irpass.reduce_eliminate_, @@ -146,7 +156,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { irpass.incorporate_call_switch_, irpass.incorporate_env_getitem_bypass_recursive_, irpass.incorporate_env_getitem_switch_, - irpass.new_env_get_item_, + irpass.env_get_item_eliminate_, irpass.depend_value_elim_, irpass.all_reduce_const_elim_, }, @@ -218,7 +228,10 @@ OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irp OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig d_1 = opt::OptPassConfig({// Safe inlining - irpass.call_graph_tuple_transform_, irpass.item_tuple_or_list_eliminate_}); + irpass.call_graph_tuple_transform_, irpass.tuple_list_get_item_eliminator_, + irpass.tuple_list_get_item_const_eliminator_, irpass.tuple_list_set_item_eliminator_, + irpass.tuple_list_get_set_item_eliminator_, irpass.tuple_list_get_item_depend_reorder_, + irpass.tuple_list_convert_item_index_to_positive_}); OptPassGroupMap map_a({{"d_1", d_1}, {"renormalize", opt::OptPassConfig::Renormalize()}}); @@ -226,13 +239,31 @@ OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib } OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { - opt::OptPassConfig b_1 = opt::OptPassConfig( - {irpass.zero_like_fill_zero_, irpass.item_tuple_or_list_eliminate_, irpass.float_tuple_getitem_switch_, - irpass.reset_defer_inline_, irpass.inline_, irpass.updatestate_eliminater_, irpass.load_eliminater_, - irpass.stopgrad_eliminater_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, - irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, - irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_}, - false, true); + opt::OptPassConfig b_1 = opt::OptPassConfig({irpass.zero_like_fill_zero_, + irpass.tuple_list_get_item_eliminator_, + irpass.tuple_list_get_item_const_eliminator_, + irpass.tuple_list_set_item_eliminator_, + irpass.tuple_list_get_set_item_eliminator_, + irpass.tuple_list_get_item_depend_reorder_, + irpass.tuple_list_convert_item_index_to_positive_, + irpass.float_tuple_getitem_switch_, + irpass.reset_defer_inline_, + irpass.inline_, + irpass.updatestate_eliminater_, + irpass.load_eliminater_, + irpass.stopgrad_eliminater_, + irpass.special_op_eliminate_, + irpass.get_make_ref_eliminate_, + irpass.incorporate_env_getitem_, + irpass.incorporate_env_getitem_switch_, + irpass.env_get_item_eliminate_, + irpass.env_get_item_add_eliminate_, + irpass.env_get_set_item_eliminate_, + irpass.env_get_item_depend_swap_, + irpass.incorporate_env_getitem_switch_layer_, + irpass.value_based_eliminate_, + irpass.receive_eliminate_}, + false, true); opt::OptPassConfig b_2 = opt::OptPassConfig({ irpass.replace_refkey_by_param_, irpass.make_ref_eliminate_, diff --git a/tests/perf_test/mind_expression_perf/process_data.py b/tests/perf_test/mind_expression_perf/process_data.py index 89a968033c0..851e7a1c88c 100644 --- a/tests/perf_test/mind_expression_perf/process_data.py +++ b/tests/perf_test/mind_expression_perf/process_data.py @@ -15,9 +15,9 @@ import os import sys import json -import openpyxl as opx import matplotlib.ticker as ticker import matplotlib.pyplot as plt +import openpyxl as opx def parse_arguments(): diff --git a/tests/ut/cpp/optimizer/lib_test.cc b/tests/ut/cpp/optimizer/lib_test.cc index 4f0d3fc6bba..8fb86f3eacf 100644 --- a/tests/ut/cpp/optimizer/lib_test.cc +++ b/tests/ut/cpp/optimizer/lib_test.cc @@ -355,7 +355,10 @@ TEST_F(TestOptLib, test_tuple_getitem) { FuncGraphPtr after_2 = std::make_shared(); after_2->set_output(value_node_2); - auto patterns = std::vector({irpass.item_tuple_or_list_eliminate_}); + auto patterns = std::vector( + {irpass.tuple_list_get_item_eliminator_, irpass.tuple_list_get_item_const_eliminator_, + irpass.tuple_list_set_item_eliminator_, irpass.tuple_list_get_set_item_eliminator_, + irpass.tuple_list_get_item_depend_reorder_, irpass.tuple_list_convert_item_index_to_positive_}); ASSERT_TRUE(CheckOpt(make_get_0, after_0, patterns)); ASSERT_TRUE(CheckOpt(make_get_1, after_1, patterns)); ASSERT_TRUE(CheckOpt(make_get_const, after_2, patterns)); @@ -367,7 +370,10 @@ TEST_F(TestOptLib, test_tuple_setitem) { FuncGraphPtr after_0 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_0"); FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_1"); - auto patterns = std::vector({irpass.item_tuple_or_list_eliminate_}); + auto patterns = std::vector( + {irpass.tuple_list_get_item_eliminator_, irpass.tuple_list_get_item_const_eliminator_, + irpass.tuple_list_set_item_eliminator_, irpass.tuple_list_get_set_item_eliminator_, + irpass.tuple_list_get_item_depend_reorder_, irpass.tuple_list_convert_item_index_to_positive_}); ASSERT_TRUE(CheckOpt(before_0, after_0, patterns)); ASSERT_TRUE(CheckOpt(before_1, after_1, patterns)); @@ -379,7 +385,10 @@ TEST_F(TestOptLib, test_tuple_get_set_item) { FuncGraphPtr before_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "before_0"); FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "after_0"); - auto patterns = std::vector({irpass.item_tuple_or_list_eliminate_}); + auto patterns = std::vector( + {irpass.tuple_list_get_item_eliminator_, irpass.tuple_list_get_item_const_eliminator_, + irpass.tuple_list_set_item_eliminator_, irpass.tuple_list_get_set_item_eliminator_, + irpass.tuple_list_get_item_depend_reorder_, irpass.tuple_list_convert_item_index_to_positive_}); ASSERT_TRUE(CheckOpt(before_0, after_0, patterns)); ASSERT_TRUE(CheckOpt(before_1, after_1, patterns));