add pass to eliminate depend value

This commit is contained in:
BowenK 2020-06-19 14:45:48 +08:00
parent d638578d33
commit 8f29e7242f
6 changed files with 39 additions and 0 deletions

View File

@ -70,6 +70,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
same_eliminate_ = MakeSubstitution(SameEliminater(), "same_eliminate", prim::kPrimSameTypeShape);
check_bprop_eliminate_ = MakeSubstitution(CheckBpropEliminater(), "check_bprop_eliminate", prim::kPrimCheckBprop);
reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode<FuncGraph>);
depend_value_elim_ = MakeSubstitution(DependValueElim(), "depend_value_elim", prim::kPrimDepend);
// Env Item Eliminate
env_get_item_eliminate_ = MakeSubstitution(EnvGetItemEliminater(), "env_get_item_eliminate", prim::kPrimEnvGetItem);

View File

@ -48,6 +48,7 @@ class OptimizeIRPassLib {
SubstitutionPtr same_eliminate_;
SubstitutionPtr check_bprop_eliminate_;
SubstitutionPtr reset_defer_inline_;
SubstitutionPtr depend_value_elim_;
// Env Item Eliminate
SubstitutionPtr env_get_item_eliminate_;

View File

@ -24,9 +24,11 @@
#include "optimizer/optimizer.h"
#include "optimizer/irpass.h"
#include "ir/optimizer_caller.h"
#include "optimizer/irpass/prim_eliminate.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "ir/pattern_matcher.h"
namespace mindspore {
namespace opt {
@ -191,6 +193,17 @@ class ZeroLikeFillZero : public AnfVisitor {
AnfNodePtr y_{nullptr};
PrimitivePtr PrimFill_, PrimShape_, PrimDType_;
};
// {prim::kPrimDepend, X, ValueCond}->X
class DependValueElim : public OptimizerCaller {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
PatternNode<AnfNodePtr> x, cond;
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimDepend, x, cond), x, IsVNode(cond.GetNode(node)));
return nullptr;
}
};
} // namespace irpass
} // namespace opt
} // namespace mindspore

View File

@ -108,6 +108,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.incorporate_env_getitem_,
irpass.incorporate_env_getitem_switch_,
irpass.new_env_get_item_,
irpass.depend_value_elim_,
});
opt::OptPassConfig a_3 = opt::OptPassConfig({
irpass.same_eliminate_,

View File

@ -257,6 +257,14 @@ TEST_F(TestOptLib, test_elim_transpose) {
ASSERT_TRUE(CheckOpt(before, after, patterns));
}
TEST_F(TestOptLib, test_elim_depend_value) {
FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_depend_value", "before");
FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_depend_value", "after");
auto patterns = std::vector<SubstitutionPtr>({irpass.depend_value_elim_});
ASSERT_TRUE(CheckOpt(before, after, patterns));
}
TEST_F(TestOptLib, test_elim_tile_multiply_one) {
FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_tile_multiply_one", "before");
FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_tile_multiply_one", "after");

View File

@ -494,6 +494,21 @@ def test_elim_transpose(tag):
return fns[tag]
def test_elim_depend_value(tag):
""" test_elim_depend_value """
fns = FnDict()
depend = P.Depend()
@fns
def before(x):
return depend(x, None)
@fns
def after(x):
return x
return fns[tag]
def test_elim_tile_multiply_one(tag):
""" test_elim_tile_multiply_one """