forked from mindspore-Ecosystem/mindspore
add 3 patterns for the lamb_next_mv_rule fusion pass
This commit is contained in:
parent
c8f69f5db2
commit
a42dd21430
|
@ -104,6 +104,9 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
|
|||
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond2>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond3>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond4>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond1>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond2>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond3>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<LambNextMVRuleCond4>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<LambNextRightRule>());
|
||||
ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLrV2>());
|
||||
|
|
|
@ -116,9 +116,116 @@ const AnfNodePtr LambNextMVRule::Process(const FuncGraphPtr &func_graph, const A
|
|||
return CreateLambNextMVNode(func_graph, old_pattern_outputs, equiv);
|
||||
}
|
||||
|
||||
const BaseRef LambNextMVRuleCond1::DefinePattern() const {
  // Cond1 operand ordering: every Mul takes its constant (mul*_x_/mul*_sub_)
  // operand first, and add2 takes add2_y_ as its first addend.
  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);

  auto mul0 = VectorRef({prim::kPrimMul, mul0_x_, input4_});
  auto mul1 = VectorRef({prim::kPrimMul, mul1_sub_, input3_});
  auto mul2 = VectorRef({prim::kPrimMul, mul2_x_, input1_});
  auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_});
  auto mul4 = VectorRef({prim::kPrimMul, mul4_x_, input6_});
  auto add0 = VectorRef({add0_var_, mul0, mul1});
  auto add1 = VectorRef({add1_var_, mul2, mul3});

  auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
  auto real_div1 = VectorRef({real_div1_var_, add1, input2_});

  // real_div2 = rsqrt(add2_y_ + real_div1) * real_div0
  auto add2 = VectorRef({prim::kPrimTensorAdd, add2_y_, real_div1});
  auto sqrt0 = VectorRef({prim_rsqrt, add2});
  auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0});

  // Root of the pattern: add3 = mul4 + real_div2.
  return VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
}
|
||||
|
||||
BaseRef LambNextMVRuleCond1::DefineAnotherPattern() const {
  // Second pattern, rooted at real_div4 = real_div0 / (add2_y_ + sqrt(real_div1)).
  // The two patterns share real_div0, real_div1 and add2_y_; the shared nodes'
  // own inputs are matched loosely through sequence vars.
  const auto sqrt_prim = std::make_shared<Primitive>(kSqrtOpName);
  const auto real_div_prim = std::make_shared<Primitive>(kRealDivOpName);
  VarPtr div0_inputs = std::make_shared<SeqVar>();
  VarPtr div1_inputs = std::make_shared<SeqVar>();
  VectorRef div0({real_div0_var_, div0_inputs});
  VectorRef div1({real_div1_var_, div1_inputs});

  VectorRef sqrt1({sqrt_prim, div1});
  // Cond1: add2_y_ is the first addend of add4.
  VectorRef denom({prim::kPrimTensorAdd, add2_y_, sqrt1});
  return VectorRef({real_div_prim, div0, denom});
}
|
||||
|
||||
const BaseRef LambNextMVRuleCond2::DefinePattern() const {
  // Cond2 operand ordering: mul0/mul1/mul2/mul4 take the variable input first,
  // while mul3 keeps its constant first; add2 takes add2_y_ as first addend.
  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);

  auto mul0 = VectorRef({prim::kPrimMul, input4_, mul0_x_});
  auto mul1 = VectorRef({prim::kPrimMul, input3_, mul1_sub_});
  auto mul2 = VectorRef({prim::kPrimMul, input1_, mul2_x_});
  auto mul3 = VectorRef({prim::kPrimMul, mul3_sub1_, input0_});
  auto mul4 = VectorRef({prim::kPrimMul, input6_, mul4_x_});
  auto add0 = VectorRef({add0_var_, mul0, mul1});
  auto add1 = VectorRef({add1_var_, mul2, mul3});

  auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
  auto real_div1 = VectorRef({real_div1_var_, add1, input2_});

  // real_div2 = rsqrt(add2_y_ + real_div1) * real_div0
  auto add2 = VectorRef({prim::kPrimTensorAdd, add2_y_, real_div1});
  auto sqrt0 = VectorRef({prim_rsqrt, add2});
  auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0});

  // Root of the pattern: add3 = mul4 + real_div2.
  return VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
}
|
||||
|
||||
BaseRef LambNextMVRuleCond2::DefineAnotherPattern() const {
  // Second pattern, rooted at real_div4 = real_div0 / (sqrt(real_div1) + add2_y_).
  // The two patterns share real_div0, real_div1 and add2_y_; the shared nodes'
  // own inputs are matched loosely through sequence vars.
  const auto sqrt_prim = std::make_shared<Primitive>(kSqrtOpName);
  const auto real_div_prim = std::make_shared<Primitive>(kRealDivOpName);
  VarPtr div0_inputs = std::make_shared<SeqVar>();
  VarPtr div1_inputs = std::make_shared<SeqVar>();
  VectorRef div0({real_div0_var_, div0_inputs});
  VectorRef div1({real_div1_var_, div1_inputs});

  VectorRef sqrt1({sqrt_prim, div1});
  // Cond2: sqrt1 is the first addend of add4.
  VectorRef denom({prim::kPrimTensorAdd, sqrt1, add2_y_});
  return VectorRef({real_div_prim, div0, denom});
}
|
||||
|
||||
const BaseRef LambNextMVRuleCond3::DefinePattern() const {
  // Cond3 operand ordering: every Mul takes the variable input first, and
  // add2 takes real_div1 as its first addend.
  const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);

  auto mul0 = VectorRef({prim::kPrimMul, input4_, mul0_x_});
  auto mul1 = VectorRef({prim::kPrimMul, input3_, mul1_sub_});
  auto mul2 = VectorRef({prim::kPrimMul, input1_, mul2_x_});
  auto mul3 = VectorRef({prim::kPrimMul, input0_, mul3_sub1_});
  auto mul4 = VectorRef({prim::kPrimMul, input6_, mul4_x_});
  auto add0 = VectorRef({add0_var_, mul0, mul1});
  auto add1 = VectorRef({add1_var_, mul2, mul3});

  auto real_div0 = VectorRef({real_div0_var_, add0, input5_});
  auto real_div1 = VectorRef({real_div1_var_, add1, input2_});

  // real_div2 = rsqrt(real_div1 + add2_y_) * real_div0
  auto add2 = VectorRef({prim::kPrimTensorAdd, real_div1, add2_y_});
  auto sqrt0 = VectorRef({prim_rsqrt, add2});
  auto real_div2 = VectorRef({real_div2_var_, sqrt0, real_div0});

  // Root of the pattern: add3 = mul4 + real_div2.
  return VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
}
|
||||
|
||||
BaseRef LambNextMVRuleCond3::DefineAnotherPattern() const {
  // Second pattern, rooted at real_div4 = real_div0 / (sqrt(real_div1) + add2_y_).
  // The two patterns share real_div0, real_div1 and add2_y_; the shared nodes'
  // own inputs are matched loosely through sequence vars.
  const auto sqrt_prim = std::make_shared<Primitive>(kSqrtOpName);
  const auto real_div_prim = std::make_shared<Primitive>(kRealDivOpName);
  VarPtr div0_inputs = std::make_shared<SeqVar>();
  VarPtr div1_inputs = std::make_shared<SeqVar>();
  VectorRef div0({real_div0_var_, div0_inputs});
  VectorRef div1({real_div1_var_, div1_inputs});

  VectorRef sqrt1({sqrt_prim, div1});
  // Cond3: sqrt1 is the first addend of add4.
  VectorRef denom({prim::kPrimTensorAdd, sqrt1, add2_y_});
  return VectorRef({real_div_prim, div0, denom});
}
|
||||
|
||||
const BaseRef LambNextMVRuleCond4::DefinePattern() const {
|
||||
const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
|
||||
MS_EXCEPTION_IF_NULL(prim_rsqrt);
|
||||
|
||||
auto mul0 = VectorRef({prim::kPrimMul, mul0_x_, input4_});
|
||||
auto mul1 = VectorRef({prim::kPrimMul, mul1_sub_, input3_});
|
||||
|
@ -140,13 +247,9 @@ const BaseRef LambNextMVRuleCond4::DefinePattern() const {
|
|||
|
||||
BaseRef LambNextMVRuleCond4::DefineAnotherPattern() const {
|
||||
const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
|
||||
MS_EXCEPTION_IF_NULL(prim_sqrt);
|
||||
const auto prim_real_div = std::make_shared<Primitive>(kRealDivOpName);
|
||||
MS_EXCEPTION_IF_NULL(prim_real_div);
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
VarPtr Ys = std::make_shared<SeqVar>();
|
||||
MS_EXCEPTION_IF_NULL(Xs);
|
||||
MS_EXCEPTION_IF_NULL(Ys);
|
||||
// Two patterns share: real_div0, real_div1, add2_y_
|
||||
VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
|
||||
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
|
||||
|
|
|
@ -87,6 +87,33 @@ class LambNextMVRule : public MultipleOutputPatternProcessPass {
|
|||
VarPtr real_div2_var_;
|
||||
};
|
||||
|
||||
// Pattern variant 1 of the LambNextMV fusion rule: matches the graph where
// each Mul takes its constant operand first (see DefinePattern/
// DefineAnotherPattern in the .cc file for the exact operand ordering).
class LambNextMVRuleCond1 : public LambNextMVRule {
 public:
  // multigraph is forwarded to the base pass; presumably it allows matching
  // across child graphs — confirm against the base pass's semantics.
  explicit LambNextMVRuleCond1(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond1", multigraph) {}

  ~LambNextMVRuleCond1() override = default;
  const BaseRef DefinePattern() const override;
  BaseRef DefineAnotherPattern() const override;
};
|
||||
|
||||
// Pattern variant 2 of the LambNextMV fusion rule: matches the graph where
// most Muls take the variable operand first while mul3 keeps its constant
// first (see DefinePattern/DefineAnotherPattern in the .cc file).
class LambNextMVRuleCond2 : public LambNextMVRule {
 public:
  // multigraph is forwarded to the base pass; presumably it allows matching
  // across child graphs — confirm against the base pass's semantics.
  explicit LambNextMVRuleCond2(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond2", multigraph) {}

  ~LambNextMVRuleCond2() override = default;
  const BaseRef DefinePattern() const override;
  BaseRef DefineAnotherPattern() const override;
};
|
||||
|
||||
// Pattern variant 3 of the LambNextMV fusion rule: matches the graph where
// every Mul takes the variable operand first and add2 takes real_div1 first
// (see DefinePattern/DefineAnotherPattern in the .cc file).
class LambNextMVRuleCond3 : public LambNextMVRule {
 public:
  // multigraph is forwarded to the base pass; presumably it allows matching
  // across child graphs — confirm against the base pass's semantics.
  explicit LambNextMVRuleCond3(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond3", multigraph) {}

  ~LambNextMVRuleCond3() override = default;
  const BaseRef DefinePattern() const override;
  BaseRef DefineAnotherPattern() const override;
};
|
||||
|
||||
class LambNextMVRuleCond4 : public LambNextMVRule {
|
||||
public:
|
||||
explicit LambNextMVRuleCond4(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond4", multigraph) {}
|
||||
|
|
|
@ -244,5 +244,125 @@ TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond4_unmatched_real_div1) {
|
|||
|
||||
EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
|
||||
}
|
||||
|
||||
TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond1_fusion) {
  // Run the cond1 fusion pass on the "before" graph; the result must equal
  // the hand-written "after" graph.
  FuncGraphPtr before_graph = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond1", "before");
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  // All 13 graph parameters share the same tensor abstract.
  AbstractBasePtrList args_spec_list(13, x_abstract);
  auto kernel_graph = GetKernelGraph(before_graph, args_spec_list);

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pass_manager = std::make_shared<opt::PassManager>();
  pass_manager->AddPass(std::make_shared<opt::LambNextMVRuleCond1>());
  optimizer->AddPassManager(pass_manager);
  FuncGraphPtr new_graph = optimizer->Optimize(kernel_graph);

  FuncGraphPtr expected_graph = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond1", "after");
  EXPECT_TRUE(CheckEqualGraph(expected_graph, new_graph));
}
|
||||
|
||||
TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond1_unmatched) {
  // The "un_match" graph differs from "before" in one operand order, so the
  // cond1 pass must leave the graph unchanged.
  FuncGraphPtr unmatched_graph = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond1", "un_match");
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  // All 13 graph parameters share the same tensor abstract.
  AbstractBasePtrList args_spec_list(13, x_abstract);
  auto kernel_graph = GetKernelGraph(unmatched_graph, args_spec_list);
  auto origin_graph = std::make_shared<session::KernelGraph>(*kernel_graph);

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pass_manager = std::make_shared<opt::PassManager>();
  pass_manager->AddPass(std::make_shared<opt::LambNextMVRuleCond1>());
  optimizer->AddPassManager(pass_manager);
  FuncGraphPtr new_graph = optimizer->Optimize(kernel_graph);

  EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
}
|
||||
|
||||
TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond2_fusion) {
  // Run the cond2 fusion pass on the "before" graph; the result must equal
  // the hand-written "after" graph.
  FuncGraphPtr before_graph = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond2", "before");
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  // All 13 graph parameters share the same tensor abstract.
  AbstractBasePtrList args_spec_list(13, x_abstract);
  auto kernel_graph = GetKernelGraph(before_graph, args_spec_list);

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pass_manager = std::make_shared<opt::PassManager>();
  pass_manager->AddPass(std::make_shared<opt::LambNextMVRuleCond2>());
  optimizer->AddPassManager(pass_manager);
  FuncGraphPtr new_graph = optimizer->Optimize(kernel_graph);

  FuncGraphPtr expected_graph = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond2", "after");
  EXPECT_TRUE(CheckEqualGraph(expected_graph, new_graph));
}
|
||||
|
||||
TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond2_unmatched) {
  // The "un_match" graph differs from "before" in one operand order, so the
  // cond2 pass must leave the graph unchanged.
  FuncGraphPtr unmatched_graph = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond2", "un_match");
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  // All 13 graph parameters share the same tensor abstract.
  AbstractBasePtrList args_spec_list(13, x_abstract);
  auto kernel_graph = GetKernelGraph(unmatched_graph, args_spec_list);
  auto origin_graph = std::make_shared<session::KernelGraph>(*kernel_graph);

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pass_manager = std::make_shared<opt::PassManager>();
  pass_manager->AddPass(std::make_shared<opt::LambNextMVRuleCond2>());
  optimizer->AddPassManager(pass_manager);
  FuncGraphPtr new_graph = optimizer->Optimize(kernel_graph);

  EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
}
|
||||
|
||||
TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond3_fusion) {
  // Run the cond3 fusion pass on the "before" graph; the result must equal
  // the hand-written "after" graph.
  FuncGraphPtr before_graph = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond3", "before");
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  // All 13 graph parameters share the same tensor abstract.
  AbstractBasePtrList args_spec_list(13, x_abstract);
  auto kernel_graph = GetKernelGraph(before_graph, args_spec_list);

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pass_manager = std::make_shared<opt::PassManager>();
  pass_manager->AddPass(std::make_shared<opt::LambNextMVRuleCond3>());
  optimizer->AddPassManager(pass_manager);
  FuncGraphPtr new_graph = optimizer->Optimize(kernel_graph);

  FuncGraphPtr expected_graph = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond3", "after");
  EXPECT_TRUE(CheckEqualGraph(expected_graph, new_graph));
}
|
||||
|
||||
TEST_F(TestHWLambNextMVRule, test_lamb_next_mv_rule_cond3_unmatched) {
  // The "un_match" graph differs from "before" in one operand order, so the
  // cond3 pass must leave the graph unchanged.
  FuncGraphPtr unmatched_graph = get_py_fun_.CallAndParseRet("test_lamb_next_mv_rule_cond3", "un_match");
  std::vector<int> shp{2, 32, 224, 224};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  // All 13 graph parameters share the same tensor abstract.
  AbstractBasePtrList args_spec_list(13, x_abstract);
  auto kernel_graph = GetKernelGraph(unmatched_graph, args_spec_list);
  auto origin_graph = std::make_shared<session::KernelGraph>(*kernel_graph);

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pass_manager = std::make_shared<opt::PassManager>();
  pass_manager->AddPass(std::make_shared<opt::LambNextMVRuleCond3>());
  optimizer->AddPassManager(pass_manager);
  FuncGraphPtr new_graph = optimizer->Optimize(kernel_graph);

  EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -24,7 +24,6 @@ make_tuple = Primitive('make_tuple')
|
|||
tuple_getitem = Primitive('tuple_getitem')
|
||||
LambNextMV = Primitive('LambNextMV')
|
||||
|
||||
|
||||
class FnDict:
|
||||
def __init__(self):
|
||||
self.fnDict = {}
|
||||
|
@ -35,7 +34,6 @@ class FnDict:
|
|||
def __getitem__(self, name):
|
||||
return self.fnDict[name]
|
||||
|
||||
|
||||
def test_lamb_next_mv_rule_cond4(tag):
|
||||
fns = FnDict()
|
||||
|
||||
|
@ -170,3 +168,192 @@ def test_lamb_next_mv_rule_cond4(tag):
|
|||
return output
|
||||
|
||||
return fns[tag]
|
||||
|
||||
def test_lamb_next_mv_rule_cond1(tag):
    """Build test graphs for the LambNextMVRuleCond1 fusion pass.

    tag selects one of three graphs:
      - 'before': the unfused pattern the rule should match (cond1 ordering:
        constants are the first operand of every Mul, add2/add4 take
        constant_add2_y first).
      - 'after': the expected graph once the pattern is fused into LambNextMV.
      - 'un_match': identical to 'before' except add4's operand order is
        flipped, so the rule must NOT fire.
    """
    fns = FnDict()

    @fns
    def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
               constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        mul0 = Mul(constant_mul0_x, input4)
        mul1 = Mul(constant_mul1_sub, input3)
        add0 = Add(mul0, mul1)
        mul2 = Mul(constant_mul2_x, input1)
        mul3 = Mul(constant_mul3_sub1, input0)
        add1 = Add(mul2, mul3)
        real_div1 = RealDiv(add1, input2)
        add2 = Add(constant_add2_y, real_div1)
        sqrt0 = Rsqrt(add2)
        sqrt1 = Sqrt(real_div1)
        add4 = Add(constant_add2_y, sqrt1)
        real_div0 = RealDiv(add0, input5)
        real_div4 = RealDiv(real_div0, add4)
        real_div2 = Mul(sqrt0, real_div0)
        mul4 = Mul(constant_mul4_x, input6)
        add3 = Add(mul4, real_div2)
        # The fusion pass must preserve these four outputs.
        outputs = make_tuple(add3, add0, add1, real_div4)
        output = tuple_getitem(outputs, 0)
        return output

    @fns
    def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
              constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        # The whole pattern collapses into a single LambNextMV with 4 outputs.
        lamb_next_mv = LambNextMV(input0, input1, input2, input3, input4, input5, input6,
                                  constant_mul0_x, constant_mul1_sub, constant_mul2_x, constant_mul3_sub1,
                                  constant_mul4_x, constant_add2_y)
        outputs = make_tuple(tuple_getitem(lamb_next_mv, 0), tuple_getitem(lamb_next_mv, 1),
                             tuple_getitem(lamb_next_mv, 2), tuple_getitem(lamb_next_mv, 3))
        output = tuple_getitem(outputs, 0)
        return make_tuple(output)

    @fns
    def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
                 constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        mul0 = Mul(constant_mul0_x, input4)
        mul1 = Mul(constant_mul1_sub, input3)
        add0 = Add(mul0, mul1)
        mul2 = Mul(constant_mul2_x, input1)
        mul3 = Mul(constant_mul3_sub1, input0)
        add1 = Add(mul2, mul3)
        real_div1 = RealDiv(add1, input2)
        add2 = Add(constant_add2_y, real_div1)
        sqrt0 = Rsqrt(add2)
        sqrt1 = Sqrt(real_div1)
        # un match
        add4 = Add(sqrt1, constant_add2_y)
        real_div0 = RealDiv(add0, input5)
        real_div4 = RealDiv(real_div0, add4)
        real_div2 = Mul(sqrt0, real_div0)
        mul4 = Mul(constant_mul4_x, input6)
        add3 = Add(mul4, real_div2)
        outputs = make_tuple(add3, add0, add1, real_div4)
        output = tuple_getitem(outputs, 0)
        return output

    return fns[tag]
|
||||
|
||||
def test_lamb_next_mv_rule_cond2(tag):
    """Build test graphs for the LambNextMVRuleCond2 fusion pass.

    tag selects one of three graphs:
      - 'before': the unfused pattern the rule should match (cond2 ordering:
        most Muls take the variable input first, mul3 keeps the constant
        first, add4 takes sqrt1 first).
      - 'after': the expected graph once the pattern is fused into LambNextMV.
      - 'un_match': identical to 'before' except add4's operand order is
        flipped, so the rule must NOT fire.
    """
    fns = FnDict()

    @fns
    def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
               constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        mul0 = Mul(input4, constant_mul0_x)
        mul1 = Mul(input3, constant_mul1_sub)
        add0 = Add(mul0, mul1)
        mul2 = Mul(input1, constant_mul2_x)
        mul3 = Mul(constant_mul3_sub1, input0)
        add1 = Add(mul2, mul3)
        real_div1 = RealDiv(add1, input2)
        add2 = Add(constant_add2_y, real_div1)
        sqrt0 = Rsqrt(add2)
        sqrt1 = Sqrt(real_div1)
        add4 = Add(sqrt1, constant_add2_y)
        real_div0 = RealDiv(add0, input5)
        real_div4 = RealDiv(real_div0, add4)
        real_div2 = Mul(sqrt0, real_div0)
        mul4 = Mul(input6, constant_mul4_x)
        add3 = Add(mul4, real_div2)
        # The fusion pass must preserve these four outputs.
        outputs = make_tuple(add3, add0, add1, real_div4)
        output = tuple_getitem(outputs, 0)
        return output

    @fns
    def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
              constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        # The whole pattern collapses into a single LambNextMV with 4 outputs.
        lamb_next_mv = LambNextMV(input0, input1, input2, input3, input4, input5, input6,
                                  constant_mul0_x, constant_mul1_sub, constant_mul2_x, constant_mul3_sub1,
                                  constant_mul4_x, constant_add2_y)
        outputs = make_tuple(tuple_getitem(lamb_next_mv, 0), tuple_getitem(lamb_next_mv, 1),
                             tuple_getitem(lamb_next_mv, 2), tuple_getitem(lamb_next_mv, 3))
        output = tuple_getitem(outputs, 0)
        return make_tuple(output)

    @fns
    def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
                 constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        mul0 = Mul(input4, constant_mul0_x)
        mul1 = Mul(input3, constant_mul1_sub)
        add0 = Add(mul0, mul1)
        mul2 = Mul(input1, constant_mul2_x)
        mul3 = Mul(constant_mul3_sub1, input0)
        add1 = Add(mul2, mul3)
        real_div1 = RealDiv(add1, input2)
        add2 = Add(constant_add2_y, real_div1)
        sqrt0 = Rsqrt(add2)
        sqrt1 = Sqrt(real_div1)
        # un match
        add4 = Add(constant_add2_y, sqrt1)
        real_div0 = RealDiv(add0, input5)
        real_div4 = RealDiv(real_div0, add4)
        real_div2 = Mul(sqrt0, real_div0)
        mul4 = Mul(input6, constant_mul4_x)
        add3 = Add(mul4, real_div2)
        outputs = make_tuple(add3, add0, add1, real_div4)
        output = tuple_getitem(outputs, 0)
        return output

    return fns[tag]
|
||||
|
||||
def test_lamb_next_mv_rule_cond3(tag):
    """Build test graphs for the LambNextMVRuleCond3 fusion pass.

    tag selects one of three graphs:
      - 'before': the unfused pattern the rule should match (cond3 ordering:
        every Mul takes the variable input first, add2 takes real_div1 first,
        add4 takes sqrt1 first).
      - 'after': the expected graph once the pattern is fused into LambNextMV.
      - 'un_match': identical to 'before' except add4's operand order is
        flipped, so the rule must NOT fire.
    """
    fns = FnDict()

    @fns
    def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
               constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        mul0 = Mul(input4, constant_mul0_x)
        mul1 = Mul(input3, constant_mul1_sub)
        add0 = Add(mul0, mul1)
        mul2 = Mul(input1, constant_mul2_x)
        mul3 = Mul(input0, constant_mul3_sub1)
        add1 = Add(mul2, mul3)
        real_div1 = RealDiv(add1, input2)
        add2 = Add(real_div1, constant_add2_y)
        sqrt0 = Rsqrt(add2)
        sqrt1 = Sqrt(real_div1)
        add4 = Add(sqrt1, constant_add2_y)
        real_div0 = RealDiv(add0, input5)
        real_div4 = RealDiv(real_div0, add4)
        real_div2 = Mul(sqrt0, real_div0)
        mul4 = Mul(input6, constant_mul4_x)
        add3 = Add(mul4, real_div2)
        # The fusion pass must preserve these four outputs.
        outputs = make_tuple(add3, add0, add1, real_div4)
        output = tuple_getitem(outputs, 0)
        return output

    @fns
    def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
              constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        # The whole pattern collapses into a single LambNextMV with 4 outputs.
        lamb_next_mv = LambNextMV(input0, input1, input2, input3, input4, input5, input6,
                                  constant_mul0_x, constant_mul1_sub, constant_mul2_x, constant_mul3_sub1,
                                  constant_mul4_x, constant_add2_y)
        outputs = make_tuple(tuple_getitem(lamb_next_mv, 0), tuple_getitem(lamb_next_mv, 1),
                             tuple_getitem(lamb_next_mv, 2), tuple_getitem(lamb_next_mv, 3))
        output = tuple_getitem(outputs, 0)
        return make_tuple(output)

    @fns
    def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
                 constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
        mul0 = Mul(input4, constant_mul0_x)
        mul1 = Mul(input3, constant_mul1_sub)
        add0 = Add(mul0, mul1)
        mul2 = Mul(input1, constant_mul2_x)
        mul3 = Mul(input0, constant_mul3_sub1)
        add1 = Add(mul2, mul3)
        real_div1 = RealDiv(add1, input2)
        add2 = Add(real_div1, constant_add2_y)
        sqrt0 = Rsqrt(add2)
        sqrt1 = Sqrt(real_div1)
        # un match
        add4 = Add(constant_add2_y, sqrt1)
        real_div0 = RealDiv(add0, input5)
        real_div4 = RealDiv(real_div0, add4)
        real_div2 = Mul(sqrt0, real_div0)
        mul4 = Mul(input6, constant_mul4_x)
        add3 = Add(mul4, real_div2)
        outputs = make_tuple(add3, add0, add1, real_div4)
        output = tuple_getitem(outputs, 0)
        return output

    return fns[tag]
|
||||
|
|
Loading…
Reference in New Issue