diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc index fb69c2cc363..ecbdde0a096 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc @@ -99,6 +99,7 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); @@ -114,11 +115,15 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc index 4a2387d3cc3..7dc13ee7a7d 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc @@ -41,24 +41,104 @@ std::vector AdamApplyOneWithDecayRule::GetFusionNodeInputs(const Equ return {NewValueNode(prim), input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y}; } -const BaseRef AdamApplyOneWithDecayRule::DefinePattern() const { +const BaseRef AdamApplyOneWithDecayRuleCond1::DefinePattern() const { auto sqrt = std::make_shared(kSqrtOpName); auto real_div = std::make_shared(kRealDivOpName); - VectorRef mul0_pattern({prim::kPrimMul, mul0_x_, input2_}); - VectorRef mul1_pattern({prim::kPrimMul, mul1_x_, input0_}); - VectorRef square0_pattern({prim::kPrimSquare, input0_}); - VectorRef add0_pattern({add0_var_, mul0_pattern, mul1_pattern}); - VectorRef mul2_pattern({prim::kPrimMul, mul2_x_, input1_}); - VectorRef mul3_pattern({prim::kPrimMul, mul3_x_, square0_pattern}); - VectorRef add1_pattern({add1_var_, mul2_pattern, mul3_pattern}); - VectorRef sqrt0_pattern({sqrt, add1_pattern}); - VectorRef add2_pattern({prim::kPrimTensorAdd, sqrt0_pattern, add2_y_}); - VectorRef mul4_pattern({prim::kPrimMul, mul4_x_, input3_}); - VectorRef real_div_pattern({real_div, add0_pattern, add2_pattern}); - VectorRef add3_pattern({prim::kPrimTensorAdd, real_div_pattern, mul4_pattern}); - VectorRef mul5_pattern({prim::kPrimMul, input4_, add3_pattern}); - VectorRef sub0_pattern({prim::kPrimSub, input3_, mul5_pattern}); - return sub0_pattern; + VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); + VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); + VectorRef square0({prim::kPrimSquare, input0_}); + VectorRef add0({add0_var_, mul0, mul1}); + VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); + VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); + VectorRef add1({add1_var_, mul2, mul3}); + VectorRef sqrt0({sqrt, add1}); + VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0}); + VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); + VectorRef real_div0({real_div, add0, add2}); + VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); + VectorRef mul5({prim::kPrimMul, input4_, add3}); + VectorRef sub0({prim::kPrimSub, input3_, mul5}); + return sub0; +} + +const BaseRef AdamApplyOneWithDecayRuleCond2::DefinePattern() const { + auto sqrt = std::make_shared(kSqrtOpName); + auto real_div = std::make_shared(kRealDivOpName); + VectorRef mul0({prim::kPrimMul, input2_, mul0_x_}); + VectorRef mul1({prim::kPrimMul, input0_, mul1_x_}); + VectorRef square0({prim::kPrimSquare, input0_}); + VectorRef add0({add0_var_, mul0, mul1}); + VectorRef mul2({prim::kPrimMul, input1_, mul2_x_}); + VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); + VectorRef add1({add1_var_, mul2, mul3}); + VectorRef sqrt0({sqrt, add1}); + VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); + VectorRef mul4({prim::kPrimMul, input3_, mul4_x_}); + VectorRef real_div0({real_div, add0, add2}); + VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); + VectorRef mul5({prim::kPrimMul, add3, input4_}); + VectorRef sub0({prim::kPrimSub, input3_, mul5}); + return sub0; +} + +const BaseRef AdamApplyOneWithDecayRuleCond3::DefinePattern() const { + auto sqrt = std::make_shared(kSqrtOpName); + auto real_div = std::make_shared(kRealDivOpName); + VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); + VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); + VectorRef square0({prim::kPrimSquare, input0_}); + VectorRef add0({add0_var_, mul0, mul1}); + VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); + VectorRef mul3({prim::kPrimMul, square0, mul3_x_}); + VectorRef add1({add1_var_, mul2, mul3}); + VectorRef sqrt0({sqrt, add1}); + VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); + VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); + VectorRef real_div0({real_div, add0, add2}); + VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); + VectorRef mul5({prim::kPrimMul, add3, input4_}); + VectorRef sub0({prim::kPrimSub, input3_, mul5}); + return sub0; +} + +const BaseRef AdamApplyOneWithDecayRuleCond4::DefinePattern() const { + auto sqrt = std::make_shared(kSqrtOpName); + auto real_div = std::make_shared(kRealDivOpName); + VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); + VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); + VectorRef square0({prim::kPrimSquare, input0_}); + VectorRef add0({add0_var_, mul0, mul1}); + VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); + VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); + VectorRef add1({add1_var_, mul2, mul3}); + VectorRef sqrt0({sqrt, add1}); + VectorRef add2({prim::kPrimTensorAdd, add2_y_, sqrt0}); + VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); + VectorRef real_div0({real_div, add0, add2}); + VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); + VectorRef mul5({prim::kPrimMul, add3, input4_}); + VectorRef sub0({prim::kPrimSub, input3_, mul5}); + return sub0; +} + +const BaseRef AdamApplyOneWithDecayRuleCond5::DefinePattern() const { + auto sqrt = std::make_shared(kSqrtOpName); + auto real_div = std::make_shared(kRealDivOpName); + VectorRef mul0({prim::kPrimMul, mul0_x_, input2_}); + VectorRef mul1({prim::kPrimMul, mul1_x_, input0_}); + VectorRef square0({prim::kPrimSquare, input0_}); + VectorRef add0({add0_var_, mul0, mul1}); + VectorRef mul2({prim::kPrimMul, mul2_x_, input1_}); + VectorRef mul3({prim::kPrimMul, mul3_x_, square0}); + VectorRef add1({add1_var_, mul2, mul3}); + VectorRef sqrt0({sqrt, add1}); + VectorRef add2({prim::kPrimTensorAdd, sqrt0, add2_y_}); + VectorRef mul4({prim::kPrimMul, mul4_x_, input3_}); + VectorRef real_div0({real_div, add0, add2}); + VectorRef add3({prim::kPrimTensorAdd, mul4, real_div0}); + VectorRef mul5({prim::kPrimMul, add3, input4_}); + VectorRef sub0({prim::kPrimSub, input3_, mul5}); + return sub0; } const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h index 72c54f35352..742295dd9c4 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h @@ -18,14 +18,15 @@ #include #include +#include #include "pre_activate/common/optimizer.h" #include "utils/utils.h" namespace mindspore { namespace opt { class AdamApplyOneWithDecayRule : public PatternProcessPass { public: - explicit AdamApplyOneWithDecayRule(bool multigraph = true) - : PatternProcessPass("adam_apply_one_with_decay_rule", multigraph) { + explicit AdamApplyOneWithDecayRule(const std::string &name = "adam_apply_one_with_decay_rule", bool multigraph = true) + : PatternProcessPass(name, multigraph) { input0_ = std::make_shared(); input1_ = std::make_shared(); input2_ = std::make_shared(); @@ -41,10 +42,10 @@ class AdamApplyOneWithDecayRule : public PatternProcessPass { add1_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); } ~AdamApplyOneWithDecayRule() override = default; - const BaseRef DefinePattern() const override; + const BaseRef DefinePattern() const override = 0; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - private: + protected: std::vector GetFusionNodeInputs(const EquivPtr &equiv) const; VarPtr input0_; VarPtr input1_; @@ -60,6 +61,51 @@ class AdamApplyOneWithDecayRule : public PatternProcessPass { VarPtr add0_var_; VarPtr add1_var_; }; + +class AdamApplyOneWithDecayRuleCond1 : public AdamApplyOneWithDecayRule { + public: + explicit AdamApplyOneWithDecayRuleCond1(bool multigraph = true) + : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond1", multigraph) {} + + ~AdamApplyOneWithDecayRuleCond1() override = default; + const BaseRef DefinePattern() const override; +}; + +class AdamApplyOneWithDecayRuleCond2 : public AdamApplyOneWithDecayRule { + public: + explicit AdamApplyOneWithDecayRuleCond2(bool multigraph = true) + : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond2", multigraph) {} + + ~AdamApplyOneWithDecayRuleCond2() override = default; + const BaseRef DefinePattern() const override; +}; + +class AdamApplyOneWithDecayRuleCond3 : public AdamApplyOneWithDecayRule { + public: + explicit AdamApplyOneWithDecayRuleCond3(bool multigraph = true) + : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond3", multigraph) {} + + ~AdamApplyOneWithDecayRuleCond3() override = default; + const BaseRef DefinePattern() const override; +}; + +class AdamApplyOneWithDecayRuleCond4 : public AdamApplyOneWithDecayRule { + public: + explicit AdamApplyOneWithDecayRuleCond4(bool multigraph = true) + : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond4", multigraph) {} + + ~AdamApplyOneWithDecayRuleCond4() override = default; + const BaseRef DefinePattern() const override; +}; + +class AdamApplyOneWithDecayRuleCond5 : public AdamApplyOneWithDecayRule { + public: + explicit AdamApplyOneWithDecayRuleCond5(bool multigraph = true) + : AdamApplyOneWithDecayRule("adam_apply_one_with_decay_rule_cond5", multigraph) {} + + ~AdamApplyOneWithDecayRuleCond5() override = default; + const BaseRef DefinePattern() const override; +}; } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_ADAM_APPLY_ONE_WITH_DECAY_RULE_H_ diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule_test.cc index 52cb0017bbc..014e60f5792 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule_test.cc @@ -30,8 +30,8 @@ class TestHWOptimizeAdamApplyOneWithDecayRule : public BackendCommon { UT::PyFuncGraphFetcher get_py_fun_; }; -TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule) { - FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "before"); +TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond1) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond1", "before"); std::vector shp{2, 32, 224, 224}; auto x_abstract = std::make_shared(kFloat32, shp); @@ -43,16 +43,16 @@ TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_r auto optimizer = std::make_shared(); auto pm = std::make_shared(); - pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared()); optimizer->AddPassManager(pm); FuncGraphPtr new_graph = optimizer->Optimize(fg); - FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "after"); + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond1", "after"); EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); } -TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_no_match) { - FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule", "no_match"); +TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond2) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond2", "before"); std::vector shp{2, 32, 224, 224}; auto x_abstract = std::make_shared(kFloat32, shp); @@ -61,15 +61,78 @@ TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_no_match) { args_spec_list.push_back(x_abstract); } auto fg = GetKernelGraph(g, args_spec_list); - auto origin_graph = std::make_shared(*fg); auto optimizer = std::make_shared(); auto pm = std::make_shared(); - pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared()); optimizer->AddPassManager(pm); FuncGraphPtr new_graph = optimizer->Optimize(fg); - EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph)); + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond2", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} + +TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond3) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond3", "before"); + + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 11; ++i) { + args_spec_list.push_back(x_abstract); + } + auto fg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(fg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond3", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} + +TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond4) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond4", "before"); + + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 11; ++i) { + args_spec_list.push_back(x_abstract); + } + auto fg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(fg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond4", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); +} + +TEST_F(TestHWOptimizeAdamApplyOneWithDecayRule, test_adam_apply_one_with_decay_rule_cond5) { + FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond5", "before"); + + std::vector shp{2, 32, 224, 224}; + auto x_abstract = std::make_shared(kFloat32, shp); + AbstractBasePtrList args_spec_list; + for (size_t i = 0; i < 11; ++i) { + args_spec_list.push_back(x_abstract); + } + auto fg = GetKernelGraph(g, args_spec_list); + + auto optimizer = std::make_shared(); + auto pm = std::make_shared(); + pm->AddPass(std::make_shared()); + optimizer->AddPassManager(pm); + FuncGraphPtr new_graph = optimizer->Optimize(fg); + + FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_adam_apply_one_with_decay_rule_cond5", "after"); + EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); } } // namespace opt } // namespace mindspore diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/adam_apply_one_with_decay_rule.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/adam_apply_one_with_decay_rule.py index 539b91404d0..a13b1bc5834 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/adam_apply_one_with_decay_rule.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/adam_apply_one_with_decay_rule.py @@ -89,3 +89,168 @@ def test_adam_apply_one_with_decay_rule(tag): return make_tuple(make_tuple(item0, item1, item2)) return fns[tag] + + +def test_adam_apply_one_with_decay_rule_cond1(tag): + fns = FnDict() + + @fns + def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): + mul0 = mul(mul0_x, input2) + mul1 = mul(mul1_x, input0) + square0 = square(input0) + add0 = add(mul0, mul1) + mul2 = mul(mul2_x, input1) + mul3 = mul(mul3_x, square0) + add1 = add(mul2, mul3) + sqrt0 = sqrt(add1) + add2 = add(add2_y, sqrt0) + mul4 = mul(mul4_x, input3) + real_div0 = real_div(add0, add2) + add3 = add(mul4, real_div0) + mul5 = mul(input4, add3) + sub0 = sub(input3, mul5) + return make_tuple(add1, add0, sub0) + + @fns + def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): + res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, + add2_y) + item0 = tuple_getitem(res, 0) + item1 = tuple_getitem(res, 1) + item2 = tuple_getitem(res, 2) + return make_tuple(make_tuple(item0, item1, item2)) + + return fns[tag] + + +def test_adam_apply_one_with_decay_rule_cond2(tag): + fns = FnDict() + + @fns + def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): + mul0 = mul(input2, mul0_x) + mul1 = mul(input0, mul1_x) + square0 = square(input0) + add0 = add(mul0, mul1) + mul2 = mul(input1, mul2_x) + mul3 = mul(mul3_x, square0) + add1 = add(mul2, mul3) + sqrt0 = sqrt(add1) + add2 = add(sqrt0, add2_y) + mul4 = mul(input3, mul4_x) + real_div0 = real_div(add0, add2) + add3 = add(mul4, real_div0) + mul5 = mul(add3, input4) + sub0 = sub(input3, mul5) + return make_tuple(add1, add0, sub0) + + @fns + def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): + res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, + add2_y) + item0 = tuple_getitem(res, 0) + item1 = tuple_getitem(res, 1) + item2 = tuple_getitem(res, 2) + return make_tuple(make_tuple(item0, item1, item2)) + + return fns[tag] + + +def test_adam_apply_one_with_decay_rule_cond3(tag): + fns = FnDict() + + @fns + def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): + mul0 = mul(mul0_x, input2) + mul1 = mul(mul1_x, input0) + square0 = square(input0) + add0 = add(mul0, mul1) + mul2 = mul(mul2_x, input1) + mul3 = mul(square0, mul3_x) + add1 = add(mul2, mul3) + sqrt0 = sqrt(add1) + add2 = add(sqrt0, add2_y) + mul4 = mul(mul4_x, input3) + real_div0 = real_div(add0, add2) + add3 = add(mul4, real_div0) + mul5 = mul(add3, input4) + sub0 = sub(input3, mul5) + return make_tuple(add1, add0, sub0) + + @fns + def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): + res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, + add2_y) + item0 = tuple_getitem(res, 0) + item1 = tuple_getitem(res, 1) + item2 = tuple_getitem(res, 2) + return make_tuple(make_tuple(item0, item1, item2)) + + return fns[tag] + + +def test_adam_apply_one_with_decay_rule_cond4(tag): + fns = FnDict() + + @fns + def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): + mul0 = mul(mul0_x, input2) + mul1 = mul(mul1_x, input0) + square0 = square(input0) + add0 = add(mul0, mul1) + mul2 = mul(mul2_x, input1) + mul3 = mul(mul3_x, square0) + add1 = add(mul2, mul3) + sqrt0 = sqrt(add1) + add2 = add(add2_y, sqrt0) + mul4 = mul(mul4_x, input3) + real_div0 = real_div(add0, add2) + add3 = add(mul4, real_div0) + mul5 = mul(add3, input4) + sub0 = sub(input3, mul5) + return make_tuple(add1, add0, sub0) + + @fns + def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): + res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, + add2_y) + item0 = tuple_getitem(res, 0) + item1 = tuple_getitem(res, 1) + item2 = tuple_getitem(res, 2) + return make_tuple(make_tuple(item0, item1, item2)) + + return fns[tag] + + +def test_adam_apply_one_with_decay_rule_cond5(tag): + fns = FnDict() + + @fns + def before(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): + mul0 = mul(mul0_x, input2) + mul1 = mul(mul1_x, input0) + square0 = square(input0) + add0 = add(mul0, mul1) + mul2 = mul(mul2_x, input1) + mul3 = mul(mul3_x, square0) + add1 = add(mul2, mul3) + sqrt0 = sqrt(add1) + add2 = add(sqrt0, add2_y) + mul4 = mul(mul4_x, input3) + real_div0 = real_div(add0, add2) + add3 = add(mul4, real_div0) + mul5 = mul(add3, input4) + sub0 = sub(input3, mul5) + return make_tuple(add1, add0, sub0) + + @fns + def after(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y): + res = adam_apply_one_with_decay(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, + add2_y) + item0 = tuple_getitem(res, 0) + item1 = tuple_getitem(res, 1) + item2 = tuple_getitem(res, 2) + return make_tuple(make_tuple(item0, item1, item2)) + + return fns[tag]