separate irpass EnvGetItemEliminater and ItemTupleOrListEliminator

This commit is contained in:
huangbingjian 2021-03-30 19:30:13 +08:00
parent f47767b361
commit 134d5dfe4b
7 changed files with 101 additions and 96 deletions

View File

@ -70,9 +70,25 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
float_depend_g_call_ = MakeSubstitution(std::make_shared<FloatDependGCall>(), "float_depend_g_call", IsCNodeDup);
// ops eliminate
item_tuple_or_list_eliminate_ = MakeSubstitution(
std::make_shared<ItemTupleOrListEliminator>(), "item_tuple_or_list_eliminate",
tuple_list_get_item_eliminator_ =
MakeSubstitution(std::make_shared<TupleListGetitemEliminator>(), "tuple_list_get_item_eliminator",
{prim::kPrimTupleGetItem, prim::kPrimListGetItem});
tuple_list_get_item_const_eliminator_ =
MakeSubstitution(std::make_shared<TupleListGetitemConstEliminator>(), "tuple_list_get_item_const_eliminator",
{prim::kPrimTupleGetItem, prim::kPrimListGetItem});
tuple_list_set_item_eliminator_ =
MakeSubstitution(std::make_shared<TupleListSetitemEliminator>(), "tuple_list_set_item_eliminator",
{prim::kPrimTupleSetItem, prim::kPrimListSetItem});
tuple_list_get_set_item_eliminator_ =
MakeSubstitution(std::make_shared<TupleListGetSetitemEliminator>(), "tuple_list_get_set_item_eliminator",
{prim::kPrimTupleGetItem, prim::kPrimListGetItem});
tuple_list_get_item_depend_reorder_ =
MakeSubstitution(std::make_shared<TupleListGetitemDependReorder>(), "tuple_list_get_item_depend_reorder",
{prim::kPrimTupleGetItem, prim::kPrimListGetItem});
tuple_list_convert_item_index_to_positive_ = MakeSubstitution(
std::make_shared<TupleListConvertItemIndexToPositive>(), "tuple_list_convert_item_index_to_positive",
{prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem, prim::kPrimListSetItem});
tile_eliminate_ = MakeSubstitution(std::make_shared<TileEliminater>(), "tile_eliminate", prim::kPrimTile);
cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast);
reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape);
@ -99,7 +115,13 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// Env Item Eliminate
env_get_item_eliminate_ =
MakeSubstitution(std::make_shared<EnvGetItemEliminater>(), "env_get_item_eliminate", prim::kPrimEnvGetItem);
new_env_get_item_ = MakeSubstitution(std::make_shared<NewEnvGetItem>(), "new_env_get_item", prim::kPrimEnvGetItem);
env_get_item_add_eliminate_ =
MakeSubstitution(std::make_shared<EnvGetItemAddEliminater>(), "env_get_item_add_eliminate_", prim::kPrimEnvGetItem);
env_get_set_item_eliminate_ =
MakeSubstitution(std::make_shared<EnvGetSetItemEliminater>(), "env_get_set_item_eliminate", prim::kPrimEnvGetItem);
env_get_item_depend_swap_ =
MakeSubstitution(std::make_shared<EnvGetItemDependSwap>(), "env_get_item_depend_swap", prim::kPrimEnvGetItem);
incorporate_env_getitem_bypass_recursive_ =
MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(true), "incorporate_env_get_item", prim::kPrimEnvGetItem);
incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitch>(),

View File

@ -39,7 +39,13 @@ class OptimizeIRPassLib {
SubstitutionPtr adjust_all_reduce_mul_add_;
SubstitutionPtr float_depend_g_call_;
// ops eliminate
SubstitutionPtr item_tuple_or_list_eliminate_;
SubstitutionPtr tuple_list_get_item_eliminator_;
SubstitutionPtr tuple_list_get_item_const_eliminator_;
SubstitutionPtr tuple_list_set_item_eliminator_;
SubstitutionPtr tuple_list_get_set_item_eliminator_;
SubstitutionPtr tuple_list_get_item_depend_reorder_;
SubstitutionPtr tuple_list_convert_item_index_to_positive_;
SubstitutionPtr tile_eliminate_;
SubstitutionPtr cast_eliminate_;
SubstitutionPtr reshape_eliminate_;
@ -57,7 +63,9 @@ class OptimizeIRPassLib {
// Env Item Eliminate
SubstitutionPtr env_get_item_eliminate_;
SubstitutionPtr new_env_get_item_;
SubstitutionPtr env_get_item_add_eliminate_;
SubstitutionPtr env_get_set_item_eliminate_;
SubstitutionPtr env_get_item_depend_swap_;
SubstitutionPtr incorporate_env_getitem_;
SubstitutionPtr incorporate_env_getitem_bypass_recursive_;
SubstitutionPtr incorporate_env_getitem_switch_;

View File

@ -157,7 +157,7 @@ class EnvGetitemTransformACrossGraph {
} // namespace internal
// {prim::kPrimEnvGetItem, C1, C2, Y} -> Y
class NewEnvGetItem : public AnfVisitor {
class EnvGetItemEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
PatternNode c1, c2, y;
@ -170,10 +170,10 @@ class NewEnvGetItem : public AnfVisitor {
// {prim::kPrimEnvGetItem, {prim::kPrimEnvAdd, X, Y}, C, Z} ->
// {prim::GetPythonOps("hyper_add"), {prim::kPrimEnvGetItem, X, C, Z}, {prim::kPrimEnvGetItem, Y, C, Z}}
class AddEnvGetItem : public AnfVisitor {
class EnvGetItemAddEliminater : public AnfVisitor {
public:
AddEnvGetItem() : PrimHyperAdd_(prim::GetPythonOps("hyper_add")) {}
~AddEnvGetItem() override = default;
EnvGetItemAddEliminater() : PrimHyperAdd_(prim::GetPythonOps("hyper_add")) {}
~EnvGetItemAddEliminater() override = default;
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
is_match_ = false;
@ -211,7 +211,7 @@ class AddEnvGetItem : public AnfVisitor {
};
// {prim::kPrimEnvGetItem, {prim::kPrimEnvSetItem, X, C1, Y}, C2, Z}
class EnvGetSetItem : public AnfVisitor {
class EnvGetSetItemEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
is_match_ = false;
@ -281,7 +281,7 @@ class EnvGetSetItem : public AnfVisitor {
// {prim::kPrimEnvGetitem, {prim::kPrimDepend, X1, X2}, item, dflt} ->
// {prim::kPrimDepend, {prim::kPrimEnvGetitem, X1, item, dflt}, X2}
class SwapEnvGetItemDepend : public OptimizerCaller {
class EnvGetItemDependSwap : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
@ -297,36 +297,6 @@ class SwapEnvGetItemDepend : public OptimizerCaller {
}
};
class EnvGetItemEliminater : public OptimizerCaller {
public:
EnvGetItemEliminater()
: new_env_get_item_(std::make_shared<NewEnvGetItem>()),
add_env_get_item_(std::make_shared<AddEnvGetItem>()),
env_get_set_item_(std::make_shared<EnvGetSetItem>()),
swap_env_get_item_depend_(std::make_shared<SwapEnvGetItemDepend>()) {
eliminaters_.emplace_back(new_env_get_item_);
eliminaters_.emplace_back(add_env_get_item_);
eliminaters_.emplace_back(env_get_set_item_);
eliminaters_.emplace_back(swap_env_get_item_depend_);
}
~EnvGetItemEliminater() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
for (auto &eliminater : eliminaters_) {
new_node = (*eliminater)(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
}
return nullptr;
}
private:
OptimizerCallerPtr new_env_get_item_, add_env_get_item_, env_get_set_item_, swap_env_get_item_depend_;
std::vector<OptimizerCallerPtr> eliminaters_{};
};
// {prim::kPrimEnvGetItem, {G, Xs}, C, Y}
class IncorporateEnvGetitem : public AnfVisitor {
public:

View File

@ -38,7 +38,7 @@ namespace irpass {
// setitem([a, b, c, ...], -1, z) => setitem([a, b, c, ...], length - 1, z)
// {prim::kPrimTupleSetItem, T, N, Z}
// {prim::kPrimListSetItem, L, N, Z}
class ConvertItemIndexToPositive : public AnfVisitor {
class TupleListConvertItemIndexToPositive : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
@ -96,7 +96,7 @@ class ConvertItemIndexToPositive : public AnfVisitor {
// (a, b, c, ...)[1] => b
// {prim::kPrimTupleGetItem, {prim::kPrimMakeTuple, Xs}, C}
// {prim::kPrimListGetItem, {prim::kPrimMakeList, Xs}, C}
class GetitemEliminator : public AnfVisitor {
class TupleListGetitemEliminator : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
@ -144,7 +144,7 @@ class GetitemEliminator : public AnfVisitor {
// (a, b, c, ...)[1] => b
// {prim::kPrimTupleGetItem, C1, C}
// {prim::kPrimListGetItem, C1, C}
class GetitemConstEliminator : public AnfVisitor {
class TupleListGetitemConstEliminator : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
@ -195,7 +195,7 @@ class GetitemConstEliminator : public AnfVisitor {
// {prim::kPrimListSetItem, {prim::kPrimMakeList, a, b, c, ...}, 0, z} => {prim::kPrimMakeList, z, b, c, ...}
// {prim::kPrimTupleSetItem, (a, b, c, ...), 0, z} => {prim::kPrimMakeTuple, z, b, c, ...}
// {prim::kPrimListSetItem, [a, b, c, ...], 0, z} => {prim::kPrimMakeList, z, b, c, ...}
class SetitemEliminator : public AnfVisitor {
class TupleListSetitemEliminator : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
@ -277,7 +277,7 @@ class SetitemEliminator : public AnfVisitor {
// {prim::kPrimTupleGetItem, {prim::kPrimTupleSetItem, Y, C1, X}, C2}
// {prim::kPrimListGetItem, {prim::kPrimListSetItem, Y, C1, X}, C2}
class GetSetitemEliminator : public AnfVisitor {
class TupleListGetSetitemEliminator : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
@ -348,7 +348,7 @@ class GetSetitemEliminator : public AnfVisitor {
// {prim::kPrimDepend, {prim::kPrimTupleGetItem, X, C}, Y}
// {prim::kPrimListGetItem, {prim::kPrimDepend, X, Y}, C} ->
// {prim::kPrimDepend, {prim::kPrimListGetItem, X, C}, Y}
class GetitemDependReorder : public AnfVisitor {
class TupleListGetitemDependReorder : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
@ -405,41 +405,6 @@ class GetitemDependReorder : public AnfVisitor {
AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr};
};
class ItemTupleOrListEliminator : public OptimizerCaller {
public:
ItemTupleOrListEliminator()
: get_item_eliminator_(std::make_shared<GetitemEliminator>()),
get_item_const_eliminator_(std::make_shared<GetitemConstEliminator>()),
set_item_eliminator_(std::make_shared<SetitemEliminator>()),
get_set_item_eliminator_(std::make_shared<GetSetitemEliminator>()),
get_item_depend_reorder_(std::make_shared<GetitemDependReorder>()),
convert_item_index_to_positive_(std::make_shared<ConvertItemIndexToPositive>()) {
eliminators_.emplace_back(get_item_eliminator_);
eliminators_.emplace_back(get_item_const_eliminator_);
eliminators_.emplace_back(set_item_eliminator_);
eliminators_.emplace_back(get_set_item_eliminator_);
eliminators_.emplace_back(get_item_depend_reorder_);
eliminators_.emplace_back(convert_item_index_to_positive_);
}
~ItemTupleOrListEliminator() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
for (auto &eliminator : eliminators_) {
new_node = (*eliminator)(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
}
return nullptr;
}
private:
OptimizerCallerPtr get_item_eliminator_, get_item_const_eliminator_, set_item_eliminator_, get_set_item_eliminator_,
get_item_depend_reorder_, convert_item_index_to_positive_;
std::vector<OptimizerCallerPtr> eliminators_{};
};
} // namespace irpass
} // namespace opt
} // namespace mindspore

View File

@ -112,8 +112,18 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.replace_applicator_,
// Miscellaneous
irpass.item_tuple_or_list_eliminate_,
irpass.tuple_list_get_item_eliminator_,
irpass.tuple_list_get_item_const_eliminator_,
irpass.tuple_list_set_item_eliminator_,
irpass.tuple_list_get_set_item_eliminator_,
irpass.tuple_list_get_item_depend_reorder_,
irpass.tuple_list_convert_item_index_to_positive_,
irpass.env_get_item_eliminate_,
irpass.env_get_item_add_eliminate_,
irpass.env_get_set_item_eliminate_,
irpass.env_get_item_depend_swap_,
irpass.cast_eliminate_,
irpass.reshape_eliminate_,
irpass.reduce_eliminate_,
@ -146,7 +156,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.incorporate_call_switch_,
irpass.incorporate_env_getitem_bypass_recursive_,
irpass.incorporate_env_getitem_switch_,
irpass.new_env_get_item_,
irpass.env_get_item_eliminate_,
irpass.depend_value_elim_,
irpass.all_reduce_const_elim_,
},
@ -218,7 +228,10 @@ OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irp
OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig d_1 =
opt::OptPassConfig({// Safe inlining
irpass.call_graph_tuple_transform_, irpass.item_tuple_or_list_eliminate_});
irpass.call_graph_tuple_transform_, irpass.tuple_list_get_item_eliminator_,
irpass.tuple_list_get_item_const_eliminator_, irpass.tuple_list_set_item_eliminator_,
irpass.tuple_list_get_set_item_eliminator_, irpass.tuple_list_get_item_depend_reorder_,
irpass.tuple_list_convert_item_index_to_positive_});
OptPassGroupMap map_a({{"d_1", d_1}, {"renormalize", opt::OptPassConfig::Renormalize()}});
@ -226,12 +239,30 @@ OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib
}
OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig b_1 = opt::OptPassConfig(
{irpass.zero_like_fill_zero_, irpass.item_tuple_or_list_eliminate_, irpass.float_tuple_getitem_switch_,
irpass.reset_defer_inline_, irpass.inline_, irpass.updatestate_eliminater_, irpass.load_eliminater_,
irpass.stopgrad_eliminater_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_,
irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_,
irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_},
opt::OptPassConfig b_1 = opt::OptPassConfig({irpass.zero_like_fill_zero_,
irpass.tuple_list_get_item_eliminator_,
irpass.tuple_list_get_item_const_eliminator_,
irpass.tuple_list_set_item_eliminator_,
irpass.tuple_list_get_set_item_eliminator_,
irpass.tuple_list_get_item_depend_reorder_,
irpass.tuple_list_convert_item_index_to_positive_,
irpass.float_tuple_getitem_switch_,
irpass.reset_defer_inline_,
irpass.inline_,
irpass.updatestate_eliminater_,
irpass.load_eliminater_,
irpass.stopgrad_eliminater_,
irpass.special_op_eliminate_,
irpass.get_make_ref_eliminate_,
irpass.incorporate_env_getitem_,
irpass.incorporate_env_getitem_switch_,
irpass.env_get_item_eliminate_,
irpass.env_get_item_add_eliminate_,
irpass.env_get_set_item_eliminate_,
irpass.env_get_item_depend_swap_,
irpass.incorporate_env_getitem_switch_layer_,
irpass.value_based_eliminate_,
irpass.receive_eliminate_},
false, true);
opt::OptPassConfig b_2 = opt::OptPassConfig({
irpass.replace_refkey_by_param_,

View File

@ -15,9 +15,9 @@
import os
import sys
import json
import openpyxl as opx
import matplotlib.ticker as ticker
import matplotlib.pyplot as plt
import openpyxl as opx
def parse_arguments():

View File

@ -355,7 +355,10 @@ TEST_F(TestOptLib, test_tuple_getitem) {
FuncGraphPtr after_2 = std::make_shared<FuncGraph>();
after_2->set_output(value_node_2);
auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_or_list_eliminate_});
auto patterns = std::vector<SubstitutionPtr>(
{irpass.tuple_list_get_item_eliminator_, irpass.tuple_list_get_item_const_eliminator_,
irpass.tuple_list_set_item_eliminator_, irpass.tuple_list_get_set_item_eliminator_,
irpass.tuple_list_get_item_depend_reorder_, irpass.tuple_list_convert_item_index_to_positive_});
ASSERT_TRUE(CheckOpt(make_get_0, after_0, patterns));
ASSERT_TRUE(CheckOpt(make_get_1, after_1, patterns));
ASSERT_TRUE(CheckOpt(make_get_const, after_2, patterns));
@ -367,7 +370,10 @@ TEST_F(TestOptLib, test_tuple_setitem) {
FuncGraphPtr after_0 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_0");
FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_1");
auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_or_list_eliminate_});
auto patterns = std::vector<SubstitutionPtr>(
{irpass.tuple_list_get_item_eliminator_, irpass.tuple_list_get_item_const_eliminator_,
irpass.tuple_list_set_item_eliminator_, irpass.tuple_list_get_set_item_eliminator_,
irpass.tuple_list_get_item_depend_reorder_, irpass.tuple_list_convert_item_index_to_positive_});
ASSERT_TRUE(CheckOpt(before_0, after_0, patterns));
ASSERT_TRUE(CheckOpt(before_1, after_1, patterns));
@ -379,7 +385,10 @@ TEST_F(TestOptLib, test_tuple_get_set_item) {
FuncGraphPtr before_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "before_0");
FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "after_0");
auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_or_list_eliminate_});
auto patterns = std::vector<SubstitutionPtr>(
{irpass.tuple_list_get_item_eliminator_, irpass.tuple_list_get_item_const_eliminator_,
irpass.tuple_list_set_item_eliminator_, irpass.tuple_list_get_set_item_eliminator_,
irpass.tuple_list_get_item_depend_reorder_, irpass.tuple_list_convert_item_index_to_positive_});
ASSERT_TRUE(CheckOpt(before_0, after_0, patterns));
ASSERT_TRUE(CheckOpt(before_1, after_1, patterns));