add 3 patterns for lamb_next_mv_with_decay_rule pass

This commit is contained in:
huanghui 2020-05-21 15:56:48 +08:00
parent ca74e624e2
commit eaff850f11
7 changed files with 508 additions and 3 deletions

View File

@ -97,6 +97,9 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLRRuleFusion>());
ir_fusion_pm->AddPass(std::make_shared<ConfusionSoftmaxGradRule>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRule>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond1>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond2>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVWithDecayRuleCond3>());
ir_fusion_pm->AddPass(std::make_shared<LambNextMVRule>());
ir_fusion_pm->AddPass(std::make_shared<LambNextRightRule>());
ir_fusion_pm->AddPass(std::make_shared<LambUpdateWithLrV2>());

View File

@ -163,5 +163,128 @@ const AnfNodePtr LambNextMVWithDecayRule::Process(const FuncGraphPtr &func_graph
}
return nullptr;
}
const BaseRef LambNextMVWithDecayRuleCond1::DefineAnotherPattern() const {
const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
MS_EXCEPTION_IF_NULL(prim_rsqrt);
VarPtr Xs = std::make_shared<SeqVar>();
VarPtr Ys = std::make_shared<SeqVar>();
VarPtr Zs = std::make_shared<SeqVar>();
MS_EXCEPTION_IF_NULL(Xs);
MS_EXCEPTION_IF_NULL(Ys);
MS_EXCEPTION_IF_NULL(Zs);
VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
VectorRef mul4 = VectorRef({mul4_var_, Zs});
VectorRef add2 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, real_div1});
VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
return add3;
}
const BaseRef LambNextMVWithDecayRuleCond1::DefinePattern() const {
const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
MS_EXCEPTION_IF_NULL(prim_sqrt);
const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
MS_EXCEPTION_IF_NULL(prim_deal_div);
VectorRef mul2 = VectorRef({prim::kPrimMul, input_vars_[1], constant_mul_input_vars_[2]});
VectorRef mul3 = VectorRef({prim::kPrimMul, input_vars_[0], constant_mul_input_vars_[3]});
VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]});
VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]});
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4});
return add5;
}
const BaseRef LambNextMVWithDecayRuleCond2::DefineAnotherPattern() const {
const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
MS_EXCEPTION_IF_NULL(prim_rsqrt);
VarPtr Xs = std::make_shared<SeqVar>();
VarPtr Ys = std::make_shared<SeqVar>();
VarPtr Zs = std::make_shared<SeqVar>();
MS_EXCEPTION_IF_NULL(Xs);
MS_EXCEPTION_IF_NULL(Ys);
MS_EXCEPTION_IF_NULL(Zs);
VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
VectorRef mul4 = VectorRef({mul4_var_, Zs});
VectorRef add2 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, real_div1});
VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
return add3;
}
const BaseRef LambNextMVWithDecayRuleCond2::DefinePattern() const {
const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
MS_EXCEPTION_IF_NULL(prim_sqrt);
const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
MS_EXCEPTION_IF_NULL(prim_deal_div);
VectorRef mul2 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[2], input_vars_[1]});
VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]});
VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
VectorRef add4 = VectorRef({prim::kPrimTensorAdd, constant_add2_y_, sqrt1});
VectorRef mul0 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[0], input_vars_[4]});
VectorRef mul1 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[1], input_vars_[3]});
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
VectorRef mul4 = VectorRef({mul4_var_, constant_mul_input_vars_[4], input_vars_[6]});
VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4});
return add5;
}
const BaseRef LambNextMVWithDecayRuleCond3::DefineAnotherPattern() const {
const auto prim_rsqrt = std::make_shared<Primitive>(kRsqrtOpName);
MS_EXCEPTION_IF_NULL(prim_rsqrt);
VarPtr Xs = std::make_shared<SeqVar>();
VarPtr Ys = std::make_shared<SeqVar>();
VarPtr Zs = std::make_shared<SeqVar>();
MS_EXCEPTION_IF_NULL(Xs);
MS_EXCEPTION_IF_NULL(Ys);
MS_EXCEPTION_IF_NULL(Zs);
VectorRef real_div0 = VectorRef({real_div0_var_, Xs});
VectorRef real_div1 = VectorRef({real_div1_var_, Ys});
VectorRef mul4 = VectorRef({mul4_var_, Zs});
VectorRef add2 = VectorRef({prim::kPrimTensorAdd, real_div1, constant_add2_y_});
VectorRef sqrt0 = VectorRef({prim_rsqrt, add2});
VectorRef real_div2 = VectorRef({prim::kPrimMul, sqrt0, real_div0});
VectorRef add3 = VectorRef({prim::kPrimTensorAdd, mul4, real_div2});
return add3;
}
const BaseRef LambNextMVWithDecayRuleCond3::DefinePattern() const {
const auto prim_sqrt = std::make_shared<Primitive>(kSqrtOpName);
MS_EXCEPTION_IF_NULL(prim_sqrt);
const auto prim_deal_div = std::make_shared<Primitive>(kRealDivOpName);
MS_EXCEPTION_IF_NULL(prim_deal_div);
VectorRef mul2 = VectorRef({prim::kPrimMul, input_vars_[1], constant_mul_input_vars_[2]});
VectorRef mul3 = VectorRef({prim::kPrimMul, constant_mul_input_vars_[3], input_vars_[0]});
VectorRef add1 = VectorRef({add1_var_, mul2, mul3});
VectorRef real_div1 = VectorRef({real_div1_var_, add1, input_vars_[2]});
VectorRef sqrt1 = VectorRef({prim_sqrt, real_div1});
VectorRef add4 = VectorRef({prim::kPrimTensorAdd, sqrt1, constant_add2_y_});
VectorRef mul0 = VectorRef({prim::kPrimMul, input_vars_[4], constant_mul_input_vars_[0]});
VectorRef mul1 = VectorRef({prim::kPrimMul, input_vars_[3], constant_mul_input_vars_[1]});
VectorRef add0 = VectorRef({add0_var_, mul0, mul1});
VectorRef real_div0 = VectorRef({real_div0_var_, add0, input_vars_[5]});
VectorRef real_div4 = VectorRef({prim_deal_div, real_div0, add4});
VectorRef mul4 = VectorRef({mul4_var_, input_vars_[6], constant_mul_input_vars_[4]});
VectorRef add5 = VectorRef({prim::kPrimTensorAdd, mul4, real_div4});
return add5;
}
} // namespace opt
} // namespace mindspore

View File

@ -74,6 +74,36 @@ class LambNextMVWithDecayRule : public PatternProcessPass {
VarPtr add0_var_;
VarPtr add1_var_;
};
class LambNextMVWithDecayRuleCond1 : public LambNextMVWithDecayRule {
public:
explicit LambNextMVWithDecayRuleCond1(bool multigraph = true)
: LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond1", multigraph) {}
~LambNextMVWithDecayRuleCond1() override = default;
const BaseRef DefinePattern() const override;
const BaseRef DefineAnotherPattern() const override;
};
class LambNextMVWithDecayRuleCond2 : public LambNextMVWithDecayRule {
public:
explicit LambNextMVWithDecayRuleCond2(bool multigraph = true)
: LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond2", multigraph) {}
~LambNextMVWithDecayRuleCond2() override = default;
const BaseRef DefinePattern() const override;
const BaseRef DefineAnotherPattern() const override;
};
class LambNextMVWithDecayRuleCond3 : public LambNextMVWithDecayRule {
public:
explicit LambNextMVWithDecayRuleCond3(bool multigraph = true)
: LambNextMVWithDecayRule("lamb_next_mv_with_decay_rule_cond3", multigraph) {}
~LambNextMVWithDecayRuleCond3() override = default;
const BaseRef DefinePattern() const override;
const BaseRef DefineAnotherPattern() const override;
};
} // namespace opt
} // namespace mindspore

View File

@ -24,6 +24,8 @@
#include "pre_activate/common/helper.h"
namespace mindspore {
namespace opt {
namespace {
bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_t *mul_index) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(add);
@ -36,6 +38,14 @@ bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimMul->name()) {
if (!opt::IsUsedByOthers(graph, cnode)) {
auto full_name = cnode->fullname_with_scope();
// exclude lamb and adam, and only work in bert
if (std::string::npos != full_name.find("adam") || std::string::npos != full_name.find("lamb") ||
std::string::npos == full_name.find("bert")) {
MS_LOG(INFO) << "Mul is in adam or lamb or not a bert network, quit fusion";
return false;
}
*mul = cnode;
*mul_index = index;
return true;
@ -45,8 +55,7 @@ bool GetMul(const FuncGraphPtr &graph, const CNodePtr &add, CNodePtr *mul, size_
}
return false;
}
namespace opt {
} // namespace
const BaseRef MulAddFusion::DefinePattern() const {
VarPtr x = std::make_shared<Var>();
VarPtr y = std::make_shared<Var>();
@ -74,7 +83,12 @@ const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodeP
for (size_t index = 1; index < mul->size(); ++index) {
inputs.push_back(mul->input(index));
}
inputs.push_back(add->input(add->size() - mul_index));
auto another_input_node = add->input(add->size() - mul_index);
if (IsUsedByOthers(graph, another_input_node)) {
MS_LOG(INFO) << "Add's another input node is used by others, do not fuse";
return nullptr;
}
inputs.push_back(another_input_node);
auto fusion_node = graph->NewCNode(inputs);
fusion_node->set_scope(add->scope());
fusion_node->set_abstract(add->abstract());

View File

@ -253,5 +253,134 @@ TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_decay_rule_unmatched_rea
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule", "after");
EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond1) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond1", "before");
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 13; ++i) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond1>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond1", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond1_un_match) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond1", "un_match");
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 13; ++i) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);
auto origin_graph = std::make_shared<session::KernelGraph>(*fg);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond1>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond1", "un_match");
EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond2) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond2", "before");
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 13; ++i) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);
DumpIR("fg.ir", fg, true);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond2>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond2", "after");
DumpIR("g_after.ir", g_after, true);
DumpIR("new_graph.ir", new_graph, true);
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond2_un_match) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond2", "un_match");
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 13; ++i) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);
auto origin_graph = std::make_shared<session::KernelGraph>(*fg);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond2>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond2", "un_match");
EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond3) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond3", "before");
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 13; ++i) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);
DumpIR("fg.ir", fg, true);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond3>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond3", "after");
DumpIR("g_after.ir", g_after, true);
DumpIR("new_graph.ir", new_graph, true);
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}
TEST_F(TestHWLambNextMVWithDecayRule, test_lamb_next_mv_with_decay_rule_cond3_un_match) {
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond3", "un_match");
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < 13; ++i) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);
auto origin_graph = std::make_shared<session::KernelGraph>(*fg);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::LambNextMVWithDecayRuleCond3>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(fg);
EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph));
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_lamb_next_mv_with_decay_rule_cond3", "un_match");
EXPECT_FALSE(CheckEqualGraph(g_after, new_graph));
}
} // namespace opt
} // namespace mindspore

View File

@ -37,6 +37,10 @@ TEST_F(TestHWMulAddFusion, test_mul_add_fusion1) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);
auto scope = std::make_shared<Scope>("bert");
for (auto nd : fg->execution_order()) {
nd->set_scope(scope);
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
@ -57,6 +61,10 @@ TEST_F(TestHWMulAddFusion, test_mul_add_fusion2) {
args_spec_list.push_back(x_abstract);
}
auto fg = GetKernelGraph(g, args_spec_list);
auto scope = std::make_shared<Scope>("bert");
for (auto nd : fg->execution_order()) {
nd->set_scope(scope);
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();

View File

@ -174,3 +174,201 @@ def test_lamb_next_mv_with_decay_rule(tag):
return output
return fns[tag]
def test_lamb_next_mv_with_decay_rule_cond1(tag):
fns = FnDict()
@fns
def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
mul1 = Mul(input3, constant_mul1_sub)
mul0 = Mul(input4, constant_mul0_x)
add0 = Add(mul0, mul1)
mul2 = Mul(input1, constant_mul2_x)
mul3 = Mul(input0, constant_mul3_sub1)
add1 = Add(mul2, mul3)
real_div1 = RealDiv(add1, input2)
add2 = Add(constant_add2_y, real_div1)
sqrt1 = Sqrt(real_div1)
real_div0 = RealDiv(add0, input5)
add4 = Add(sqrt1, constant_add2_y)
sqrt0 = Rsqrt(add2)
mul4 = Mul(constant_mul4_x, input6)
real_div4 = RealDiv(real_div0, add4)
real_div2 = Mul(sqrt0, real_div0)
add5 = Add(mul4, real_div4)
add3 = Add(mul4, real_div2)
outputs = make_tuple(add3, add0, add1, add5)
output = tuple_getitem(outputs, 0)
return output
@fns
def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
lamb_next_mv_with_decay = LambNextMVWithDecay(input0, input1, input2, input3, input4, input5, input6,
constant_mul0_x, constant_mul1_sub,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x,
constant_add2_y)
outputs = make_tuple(tuple_getitem(lamb_next_mv_with_decay, 0), tuple_getitem(lamb_next_mv_with_decay, 1),
tuple_getitem(lamb_next_mv_with_decay, 2), tuple_getitem(lamb_next_mv_with_decay, 3))
output = tuple_getitem(outputs, 0)
return make_tuple(output)
@fns
def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
mul1 = Mul(input3, constant_mul1_sub)
mul0 = Mul(input4, constant_mul0_x)
add0 = Add(mul0, mul1)
mul2 = Mul(input1, constant_mul2_x)
mul3 = Mul(input0, constant_mul3_sub1)
add1 = Add(mul2, mul3)
real_div1 = RealDiv(add1, input2)
add2 = Add(constant_add2_y, real_div1)
sqrt1 = Sqrt(real_div1)
real_div0 = RealDiv(add0, input5)
add4 = Add(sqrt1, constant_add2_y)
sqrt0 = Rsqrt(add2)
mul4 = Mul(constant_mul4_x, input6)
real_div4 = RealDiv(real_div0, add4)
real_div2 = Mul(sqrt0, real_div0)
add5 = Add(mul4, real_div4)
# un match
add3 = Add(real_div2, mul4)
outputs = make_tuple(add3, add0, add1, add5)
output = tuple_getitem(outputs, 0)
return output
return fns[tag]
def test_lamb_next_mv_with_decay_rule_cond2(tag):
fns = FnDict()
@fns
def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
mul1 = Mul(constant_mul1_sub, input3)
mul0 = Mul(constant_mul0_x, input4)
add0 = Add(mul0, mul1)
mul2 = Mul(constant_mul2_x, input1)
mul3 = Mul(constant_mul3_sub1, input0)
add1 = Add(mul2, mul3)
real_div1 = RealDiv(add1, input2)
add2 = Add(constant_add2_y, real_div1)
sqrt1 = Sqrt(real_div1)
real_div0 = RealDiv(add0, input5)
add4 = Add(constant_add2_y, sqrt1)
sqrt0 = Rsqrt(add2)
mul4 = Mul(constant_mul4_x, input6)
real_div4 = RealDiv(real_div0, add4)
real_div2 = Mul(sqrt0, real_div0)
add5 = Add(mul4, real_div4)
add3 = Add(mul4, real_div2)
outputs = make_tuple(add3, add0, add1, add5)
output = tuple_getitem(outputs, 0)
return output
@fns
def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
lamb_next_mv_with_decay = LambNextMVWithDecay(input0, input1, input2, input3, input4, input5, input6,
constant_mul0_x, constant_mul1_sub,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x,
constant_add2_y)
outputs = make_tuple(tuple_getitem(lamb_next_mv_with_decay, 0), tuple_getitem(lamb_next_mv_with_decay, 1),
tuple_getitem(lamb_next_mv_with_decay, 2), tuple_getitem(lamb_next_mv_with_decay, 3))
output = tuple_getitem(outputs, 0)
return make_tuple(output)
@fns
def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
mul1 = Mul(constant_mul1_sub, input3)
mul0 = Mul(constant_mul0_x, input4)
add0 = Add(mul0, mul1)
mul2 = Mul(constant_mul2_x, input1)
mul3 = Mul(constant_mul3_sub1, input0)
add1 = Add(mul2, mul3)
real_div1 = RealDiv(add1, input2)
add2 = Add(constant_add2_y, real_div1)
sqrt1 = Sqrt(real_div1)
real_div0 = RealDiv(add0, input5)
add4 = Add(constant_add2_y, sqrt1)
sqrt0 = Rsqrt(add2)
mul4 = Mul(constant_mul4_x, input6)
real_div4 = RealDiv(real_div0, add4)
real_div2 = Mul(sqrt0, real_div0)
add5 = Add(mul4, real_div4)
# un_match
add3 = Add(real_div2, mul4)
outputs = make_tuple(add3, add0, add1, add5)
output = tuple_getitem(outputs, 0)
return output
return fns[tag]
def test_lamb_next_mv_with_decay_rule_cond3(tag):
fns = FnDict()
@fns
def before(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
mul1 = Mul(input3, constant_mul1_sub)
mul0 = Mul(input4, constant_mul0_x)
add0 = Add(mul0, mul1)
mul2 = Mul(input1, constant_mul2_x)
mul3 = Mul(constant_mul3_sub1, input0)
add1 = Add(mul2, mul3)
real_div1 = RealDiv(add1, input2)
add2 = Add(real_div1, constant_add2_y)
sqrt1 = Sqrt(real_div1)
real_div0 = RealDiv(add0, input5)
add4 = Add(sqrt1, constant_add2_y)
sqrt0 = Rsqrt(add2)
mul4 = Mul(input6, constant_mul4_x)
real_div4 = RealDiv(real_div0, add4)
real_div2 = Mul(sqrt0, real_div0)
add5 = Add(mul4, real_div4)
add3 = Add(mul4, real_div2)
outputs = make_tuple(add3, add0, add1, add5)
output = tuple_getitem(outputs, 0)
return output
@fns
def after(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
lamb_next_mv_with_decay = LambNextMVWithDecay(input0, input1, input2, input3, input4, input5, input6,
constant_mul0_x, constant_mul1_sub,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x,
constant_add2_y)
outputs = make_tuple(tuple_getitem(lamb_next_mv_with_decay, 0), tuple_getitem(lamb_next_mv_with_decay, 1),
tuple_getitem(lamb_next_mv_with_decay, 2), tuple_getitem(lamb_next_mv_with_decay, 3))
output = tuple_getitem(outputs, 0)
return make_tuple(output)
@fns
def un_match(input0, input1, input2, input3, input4, input5, input6, constant_mul0_x, constant_mul1_sub,
constant_mul2_x, constant_mul3_sub1, constant_mul4_x, constant_add2_y):
mul1 = Mul(input3, constant_mul1_sub)
mul0 = Mul(input4, constant_mul0_x)
add0 = Add(mul0, mul1)
mul2 = Mul(input1, constant_mul2_x)
mul3 = Mul(constant_mul3_sub1, input0)
add1 = Add(mul2, mul3)
real_div1 = RealDiv(add1, input2)
add2 = Add(real_div1, constant_add2_y)
sqrt1 = Sqrt(real_div1)
real_div0 = RealDiv(add0, input5)
add4 = Add(sqrt1, constant_add2_y)
sqrt0 = Rsqrt(add2)
mul4 = Mul(input6, constant_mul4_x)
real_div4 = RealDiv(real_div0, add4)
real_div2 = Mul(sqrt0, real_div0)
add5 = Add(mul4, real_div4)
# un match
add3 = Add(real_div2, mul4)
outputs = make_tuple(add3, add0, add1, add5)
output = tuple_getitem(outputs, 0)
return output
return fns[tag]