forked from mindspore-Ecosystem/mindspore
add pass to eliminate depend value
This commit is contained in:
parent
d638578d33
commit
8f29e7242f
|
@ -70,6 +70,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
same_eliminate_ = MakeSubstitution(SameEliminater(), "same_eliminate", prim::kPrimSameTypeShape);
|
||||
check_bprop_eliminate_ = MakeSubstitution(CheckBpropEliminater(), "check_bprop_eliminate", prim::kPrimCheckBprop);
|
||||
reset_defer_inline_ = MakeSubstitution(ResetDeferInline(), "reset_defer_inline", IsValueNode<FuncGraph>);
|
||||
depend_value_elim_ = MakeSubstitution(DependValueElim(), "depend_value_elim", prim::kPrimDepend);
|
||||
|
||||
// Env Item Eliminate
|
||||
env_get_item_eliminate_ = MakeSubstitution(EnvGetItemEliminater(), "env_get_item_eliminate", prim::kPrimEnvGetItem);
|
||||
|
|
|
@ -48,6 +48,7 @@ class OptimizeIRPassLib {
|
|||
SubstitutionPtr same_eliminate_;
|
||||
SubstitutionPtr check_bprop_eliminate_;
|
||||
SubstitutionPtr reset_defer_inline_;
|
||||
SubstitutionPtr depend_value_elim_;
|
||||
|
||||
// Env Item Eliminate
|
||||
SubstitutionPtr env_get_item_eliminate_;
|
||||
|
|
|
@ -24,9 +24,11 @@
|
|||
|
||||
#include "optimizer/optimizer.h"
|
||||
#include "optimizer/irpass.h"
|
||||
#include "ir/optimizer_caller.h"
|
||||
#include "optimizer/irpass/prim_eliminate.h"
|
||||
#include "ir/visitor.h"
|
||||
#include "operator/ops.h"
|
||||
#include "ir/pattern_matcher.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -191,6 +193,17 @@ class ZeroLikeFillZero : public AnfVisitor {
|
|||
AnfNodePtr y_{nullptr};
|
||||
PrimitivePtr PrimFill_, PrimShape_, PrimDType_;
|
||||
};
|
||||
|
||||
// {prim::kPrimDepend, X, ValueCond}->X
|
||||
class DependValueElim : public OptimizerCaller {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
PatternNode<AnfNodePtr> x, cond;
|
||||
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimDepend, x, cond), x, IsVNode(cond.GetNode(node)));
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -108,6 +108,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
irpass.incorporate_env_getitem_,
|
||||
irpass.incorporate_env_getitem_switch_,
|
||||
irpass.new_env_get_item_,
|
||||
irpass.depend_value_elim_,
|
||||
});
|
||||
opt::OptPassConfig a_3 = opt::OptPassConfig({
|
||||
irpass.same_eliminate_,
|
||||
|
|
|
@ -257,6 +257,14 @@ TEST_F(TestOptLib, test_elim_transpose) {
|
|||
ASSERT_TRUE(CheckOpt(before, after, patterns));
|
||||
}
|
||||
|
||||
TEST_F(TestOptLib, test_elim_depend_value) {
|
||||
FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_depend_value", "before");
|
||||
FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_depend_value", "after");
|
||||
|
||||
auto patterns = std::vector<SubstitutionPtr>({irpass.depend_value_elim_});
|
||||
ASSERT_TRUE(CheckOpt(before, after, patterns));
|
||||
}
|
||||
|
||||
TEST_F(TestOptLib, test_elim_tile_multiply_one) {
|
||||
FuncGraphPtr before = getPyFun.CallAndParseRet("test_elim_tile_multiply_one", "before");
|
||||
FuncGraphPtr after = getPyFun.CallAndParseRet("test_elim_tile_multiply_one", "after");
|
||||
|
|
|
@ -494,6 +494,21 @@ def test_elim_transpose(tag):
|
|||
|
||||
return fns[tag]
|
||||
|
||||
def test_elim_depend_value(tag):
|
||||
""" test_elim_depend_value """
|
||||
fns = FnDict()
|
||||
depend = P.Depend()
|
||||
|
||||
@fns
|
||||
def before(x):
|
||||
return depend(x, None)
|
||||
|
||||
@fns
|
||||
def after(x):
|
||||
return x
|
||||
|
||||
return fns[tag]
|
||||
|
||||
|
||||
def test_elim_tile_multiply_one(tag):
|
||||
""" test_elim_tile_multiply_one """
|
||||
|
|
Loading…
Reference in New Issue